
Commit fa93442

Author: Michael Abbott (committed)
squash PR 1407, eleven commits, 2020
1 parent 5483a12 commit fa93442

File tree: 8 files changed (+99 −179 lines)


src/Flux.jl

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ using CUDA
 const use_cuda = Ref(false)
 
 include("utils.jl")
-include("zeros.jl")
 include("onehot.jl")
 include("functor.jl")
src/deprecations.jl

Lines changed: 14 additions & 3 deletions
@@ -3,7 +3,18 @@
 @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
+
 @deprecate outdims(f, inputsize) outputsize(f, inputsize)
-@deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
-@deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
-@deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)
+@deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
+@deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
+@deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)
+
+
+# Was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
+struct Zeros
+  function Zeros()
+    @warn "Zeros() is deprecated, please simply use bias=false instead" maxlog=3
+    false
+  end
+end
+Zeros(args...) = Zeros()
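Note (not part of the diff): a minimal sketch of how this shim behaves, assuming Flux at this commit. `Zeros()` warns once and then evaluates to `false`, which the layer constructors now accept directly as "no trainable bias".

```julia
using Flux

# Old-style code that passed Zeros() as a bias now goes through the shim:
# Zeros() warns, then evaluates to `false`.
d = Dense(rand(2, 10), Flux.Zeros())   # warns; equivalent to Dense(rand(2, 10), false)
d.b == false                           # true — the layer carries no bias parameters
```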

src/layers/basic.jl

Lines changed: 30 additions & 26 deletions
@@ -67,69 +67,73 @@ end
 extraChain(::Tuple{}, x) = ()
 
 
-
 """
-    Dense(in, out, σ=identity; initW=glorot_uniform, initb=zeros, bias=true)
+    Dense(in, out, σ=identity; bias=true)
     Dense(W, b, σ=identity)
 
-Create a traditional `Dense` layer with in×out weight matrix `W` and
+Create a traditional `Dense` layer with in×out weight matrix `W` and
 bias vector `b` of length `out`. The forward pass is given by:
 
     y = σ.(W * x .+ b)
 
-The input `x` must be a vector of length `in`, a batch of vectors represented
-as an `in × N` matrix, or a higher order tensor where all dimensions
-after the first one will be treated as batch dimensions.
-
-The out `y` will be a vector of length `out` or a batch whose first
-dimension is `out` and the remaining dimensions are the same as in the input.
+The input `x` must be a vector of length `in`, or batch of vectors represented
+as an `in × N` matrix, or any array with `size(x,1) == in`.
+The out `y` will be a vector of length `out`, or a batch with `size(y) == (out, size(x)[2:end]...)`
 
-Setting `bias` to `false` will switch the bias off for the layer.
+Setting `bias=false` creates a layer without bias parameters.
 
-`initW` and `initb` are callables used to initialize weights and biases respectively,
-through the calls `initW(out, in)` and `initb(out)`.
+Two additional keywords `initW=glorot_uniform` and `initb=Flux.zeros` control the
+initialisation of parameters, when using the first constructor.
 
 # Examples
-
-```julia-repl
+```jldoctest
 julia> d = Dense(5, 2)
 Dense(5, 2)
 
-julia> d(rand(Float32, 5))
-2-element Array{Float32,1}:
- -0.16210233
-  0.123119034
+julia> d(rand(Float32, 5, 64)) |> size
+(2, 64)
 
-julia> d = Dense(5, 2; bias=false)
-Dense(5, 2)
+julia> d1 = Dense(ones(2,5), false, tanh)
+Dense(5, 2, tanh; bias=false)
+
+julia> d1(ones(5))
+2-element Array{Float64,1}:
+ 0.9999092042625951
+ 0.9999092042625951
+
+julia> params(d1) # no trainable bias
+Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
 ```
 """
-struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
+struct Dense{F,S<:AbstractMatrix,T}
   W::S
   b::T
   σ::F
 end
 
-Dense(W, b) = Dense(W, b, identity)
+Dense(W::AbstractMatrix, b) = Dense(W, b, identity)
 
 function Dense(in::Integer, out::Integer, σ = identity;
-               initW = glorot_uniform, initb = zeros, bias=true)
-  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
+               initW = glorot_uniform, initb = zeros, bias = true)
+  W = initW(out, in)
+  b = create_bias(bias, initb, out)
+  Dense(W, b, σ)
 end
 
 @functor Dense
 
 function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
   sz = size(x)
-  x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
+  x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
   x = σ.(W*x .+ b)
-  return reshape(x, :, sz[2:end]...)
+  reshape(x, :, sz[2:end]...)
 end
 
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
   l.σ == identity || print(io, ", ", l.σ)
+  l.b == false && print(io, "; bias=false")
   print(io, ")")
 end
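Note (not part of the diff): a quick sketch of the new `Dense` behaviour under these changes. With `bias=false` the stored bias is literally `false`, and broadcasting `.+ false` adds nothing, so the forward pass reduces to `σ.(W * x)`; dimensions after the first are treated as batch dimensions.

```julia
using Flux

d = Dense(3, 2; bias=false)         # identity activation, no bias parameters
x = rand(Float32, 3, 5)
d(x) ≈ d.W * x                      # true: `.+ false` is a broadcasting no-op

y = d(rand(Float32, 3, 4, 6))
size(y)                             # (2, 4, 6) — trailing dims are batch dims
```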

src/layers/conv.jl

Lines changed: 17 additions & 17 deletions
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
 
+conv_reshape_bias(c) = c.bias isa AbstractVector ?
+  reshape(c.bias, map(_->1, c.stride)..., :, 1) :
+  c.bias
+
 """
     SamePad()
 
@@ -96,7 +100,7 @@ end
 
 """
     Conv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+
 Constructs a convolutional layer with the given weight and bias.
 Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
 method.
@@ -117,7 +121,7 @@ julia> params(c1) |> length
 2
 ```
 """
-function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function Conv(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}}, σ = identity;
               stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
@@ -152,9 +156,8 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
 function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(conv(x, c.weight, cdims) .+ b)
+  (c.σ).(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::Conv)
@@ -207,16 +210,16 @@ end
 
 """
     ConvTranspose(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
 """
-function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(b, zeros, size(w, N-1))
+  bias = create_bias(b, zeros, size(w, N-1))
   return ConvTranspose(σ, w, bias, stride, pad, dilation)
 end
 
@@ -248,9 +251,8 @@ end
 
 function (c::ConvTranspose)(x::AbstractArray)
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ b)
+  (c.σ).(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -304,11 +306,11 @@ end
 
 """
     DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
 """
-function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
@@ -341,9 +343,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                     init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
 
 function (c::DepthwiseConv)(x)
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(depthwiseconv(x, c.weight, cdims) .+ b)
+  (c.σ).(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::DepthwiseConv)
@@ -392,11 +393,11 @@ end
 
 """
     CrossCor(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
 """
-function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function CrossCor(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}} = true, σ = identity;
                   stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
@@ -422,9 +423,8 @@ end
 function (c::CrossCor)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(crosscor(x, c.weight, cdims) .+ b)
+  (c.σ).(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
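Note (not part of the diff): a small sketch of what the new `conv_reshape_bias` helper does, assuming the low-level `Conv(weight, bias, activation)` constructor shown above. A bias vector is reshaped so it broadcasts along the channel dimension of the convolution output, while `bias = false` passes through unchanged and broadcasts as zero.

```julia
using Flux

w = randn(Float32, 3, 3, 2, 4)       # (kx, ky, in_channels, out_channels)

c = Conv(w, true, relu)              # `true` creates a zero bias vector (length 4)
size(Flux.conv_reshape_bias(c))      # (1, 1, 4, 1): broadcasts over width, height, batch

c0 = Conv(w, false, relu)            # no trainable bias
Flux.conv_reshape_bias(c0)           # false — `.+ false` adds nothing
```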

src/utils.jl

Lines changed: 12 additions & 9 deletions
@@ -225,18 +225,21 @@ ones(dims...) = Base.ones(Float32, dims...)
 zeros(dims...) = Base.zeros(Float32, dims...)
 
 """
-    create_bias(shallcreate::Bool, iftrue, dims...)
-    create_bias(x, ::Any...)
+    create_bias(bias::Bool, iftrue, dims...)
 
-Return a bias parameter for a layer.
+Return a bias parameter for a layer, based the value given
+to the constructor's keyword `bias=bias`.
 
-Essentially handles the allowed input options for the `bias` keyword:
-If `false`: Return the `Zeros` type which turns bias off.
-If `true` : Return the result of `iftrue(dims)`.
-If not a boolean, return self to handle the case of bias=somearray.
+* `bias == true` creates `iftrue(dims...)`, typically a dense vector of zeros.
+* `bias == false` returns `false`, to indicate no trainable bias.
+* `bias::AbstractArray` uses the array provided. It checks size but not eltype.
 """
-create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros()
-create_bias(x, ::Any...) = x
+function create_bias(bias, iftrue, dims...)
+  bias===true && return iftrue(dims...)
+  bias===false && return false
+  size(bias) == dims || throw(DimensionMismatch("expected bias of size $dims, but got $(size(bias))"))
+  return bias
+end
 
 """
     unsqueeze(xs, dim)
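Note (not part of the diff): a brief sketch of the three `create_bias` branches, using `Flux.create_bias` and `Flux.zeros` as defined above.

```julia
using Flux

Flux.create_bias(true,  Flux.zeros, 3)             # 3-element Float32 vector of zeros
Flux.create_bias(false, Flux.zeros, 3)             # false — signals "no trainable bias"
Flux.create_bias([1.0, 2.0, 3.0], Flux.zeros, 3)   # returns the given array unchanged
Flux.create_bias([1.0, 2.0], Flux.zeros, 3)        # throws DimensionMismatch: size (2,) ≠ (3,)
```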

src/zeros.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.

test/optimise.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ using Random
         Nesterov(), RMSProp(), Momentum()]
   Random.seed!(42)
   w′ = randn(10, 10)
-  b = Flux.Zeros()
+  b = false
   loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
   for t = 1: 10^5
     θ = params([w′, b])
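Note (not part of the diff): the updated test relies on `false` being a broadcasting no-op, so the regression target is unchanged; a quick check of that identity:

```julia
w′ = randn(10, 10)
x = rand(10)
(w′ * x .+ false) == (w′ * x)   # true — adding `false` adds zero element-wise
```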
