Commit 8f650ac

Michael Abbott authored and committed
squash PR 1407, eleven commits, 2020
1 parent 69e2198 commit 8f650ac

File tree: 8 files changed, +52 -144 lines

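In brief: the special bias type Flux.Zeros is removed, and "no bias" is now encoded as the plain Bool false throughout the layers, docs, and tests below. A minimal sketch of the resulting user-facing behaviour (assuming a Flux build that includes this commit):

    using Flux

    m = Dense(3, 2; bias=false)  # layer constructed without a trainable bias
    m.bias                       # false, not a Zeros() struct
    Flux.params(m)               # collects only the weight matrix; false is not trainable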

src/Flux.jl

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@ using CUDA
 const use_cuda = Ref(false)
 
 include("utils.jl")
-include("zeros.jl")
 include("onehot.jl")
 include("functor.jl")

src/deprecations.jl

Lines changed: 10 additions & 0 deletions
@@ -3,7 +3,9 @@
 @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
 @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
+
 @deprecate outdims(f, inputsize) outputsize(f, inputsize)
+
 @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
 @deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
 @deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)
@@ -18,3 +20,11 @@ function Base.getproperty(a::Dense, s::Symbol)
   end
   return getfield(a, s)
 end
+
+struct Zeros # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
+  function Zeros()
+    Base.depwarn("Zeros() and Zeros(dims...) are deprecated, please simply use bias=false instead", :Zeros)
+    false
+  end
+end
+Zeros(args...) = Zeros()
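
The shim above keeps old code running: the deprecated Zeros "constructor" warns and then simply returns false, which flows into the same code paths as bias=false. A sketch of what a caller now sees (assuming this commit):

    b = Flux.Zeros()             # prints the deprecation warning above
    b === false                  # true: the inner constructor returns false, not a struct
    Flux.Zeros(2, 3) === false   # also true, via Zeros(args...) = Zeros()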

src/layers/basic.jl

Lines changed: 1 addition & 2 deletions
@@ -67,7 +67,6 @@ end
 extraChain(::Tuple{}, x) = ()
 
 
-
 """
     Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
     Dense(W::AbstractMatrix, [bias, σ])
@@ -153,7 +152,7 @@ end
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
   l.σ == identity || print(io, ", ", l.σ)
-  l.bias == Zeros() && print(io, "; bias=false")
+  l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end
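
With the show method change above, a bias-free layer now advertises itself via the new keyword. For example (printed form sketched under the same assumption):

    Dense(3, 2, relu; bias=false)  # displays as: Dense(3, 2, relu; bias=false)
    Dense(3, 2)                    # displays as: Dense(3, 2); the default bias prints nothing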

src/layers/conv.jl

Lines changed: 8 additions & 8 deletions
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
 
+conv_reshape_bias(c) = c.bias isa AbstractVector ?
+  reshape(c.bias, map(_->1, c.stride)..., :, 1) :
+  c.bias
+
 """
     SamePad()
 
@@ -152,9 +156,8 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
 function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(conv(x, c.weight, cdims) .+ b)
+  (c.σ).(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::Conv)
@@ -248,9 +251,8 @@ end
 
 function (c::ConvTranspose)(x::AbstractArray)
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ b)
+  (c.σ).(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -341,9 +343,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                     init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
 
 function (c::DepthwiseConv)(x)
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(depthwiseconv(x, c.weight, cdims) .+ b)
+  (c.σ).(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::DepthwiseConv)
@@ -422,9 +423,8 @@ end
 function (c::CrossCor)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(crosscor(x, c.weight, cdims) .+ b)
+  (c.σ).(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
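
The new conv_reshape_bias helper, shared by all four layer types above, pads a bias vector with one singleton dimension per stride entry so it broadcasts across the spatial and batch axes of the conv output; a false bias passes through unchanged and broadcasts as adding zero. A standalone illustration (FakeLayer is hypothetical, defined only to exercise the helper):

    conv_reshape_bias(c) = c.bias isa AbstractVector ?
      reshape(c.bias, map(_->1, c.stride)..., :, 1) :
      c.bias

    struct FakeLayer{B,S}  # hypothetical stand-in with the two fields the helper reads
      bias::B
      stride::S
    end

    size(conv_reshape_bias(FakeLayer(rand(16), (1, 1))))  # (1, 1, 16, 1) for a 2-D conv
    conv_reshape_bias(FakeLayer(false, (1, 1)))           # false, an additive no-op under broadcast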

src/utils.jl

Lines changed: 2 additions & 2 deletions
@@ -297,11 +297,11 @@ Return a bias parameter for a layer, based on the value given
 to the constructor's keyword `bias=bias`.
 
 * `bias == true` creates a zero vector, of the same type as weights.
-* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias.
+* `bias == false` returns `false`, to indicate no trainable bias.
 * `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted.
 """
 function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
-  bias ? fill!(similar(weights, dims...), 0) : Zeros()
+  bias ? fill!(similar(weights, dims...), 0) : false
 end
 function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
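
create_bias now dispatches on the keyword's value: true allocates a zero vector matching the weight array's type, false is returned as-is, and a user-supplied array is size-checked (and converted if the eltype differs). A quick sketch of the three cases, calling the internal function directly as the layer constructors do:

    using Flux

    W = rand(Float32, 2, 3)
    Flux.create_bias(W, true, 2)              # 2-element Vector{Float32}, all zeros
    Flux.create_bias(W, false, 2)             # false
    Flux.create_bias(W, ones(Float32, 2), 2)  # the given array
    # Flux.create_bias(W, ones(3), 2)         # would throw DimensionMismatch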

src/zeros.jl

Lines changed: 0 additions & 52 deletions
This file was deleted.

test/optimise.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ using Random
           Nesterov(), RMSProp(), Momentum()]
   Random.seed!(42)
   w′ = randn(10, 10)
-  b = Flux.Zeros()
+  b = false
   loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
   for t = 1: 10^5
     θ = params([w′, b])
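
The swap to b = false relies on Bool arithmetic: broadcasting .+ false adds zero, so the target w′*x is unchanged, and params simply skips the non-array entry. A plain-Julia check of the first point:

    [1.0, 2.0] .+ false == [1.0, 2.0]  # true: false acts as an additive zero here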

test/utils.jl

Lines changed: 30 additions & 78 deletions
@@ -187,88 +187,39 @@ end
   @test eltype(f32(f64(m))[1].W) == Float32
 end
 
-@testset "Zeros" begin
+@testset "Without bias" begin
   m = Dense(3,2; bias=false)
-  @test f64(m).b === m.b === Zeros()
-  @test f32(m).b === m.b === Zeros()
+  @test f64(m).b === m.b === false === Zeros() # Zeros() is deprecated
+  @test f32(m).b === m.b === false
 
   @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
     o = ones(s)
     z = zeros(s)
-    Z = Zeros()
 
     @testset "Explicit" begin
      gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
      g = gfun(o, z)
-      @test gfun(o, Z) == (g[1], nothing)
+      @test gfun(o, false) == (g[1], nothing)
 
      g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
+      @test gfun(false, o) == (nothing, g[2])
    end
 
    @testset "Implicit" begin
      gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
      g = gfun(o, z)
 
-      gres = gfun(o, Z)
+      gres = gfun(o, false)
      @test gres[o] == g[o]
-      @test Z ∉ gres.params
+      @test false ∉ gres.params
+      @test length(gres.params) == 1
 
      g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
-    end
-  end
-
-  @testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3))
-    o = ones(s)
-    z = zeros(s)
-    Z = Zeros() # Only defined for 0-dim
-
-    @testset "Explicit" begin
-      gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
-      g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
-    end
-
-    @testset "Implicit" begin
-      gfun(x,y) = gradient(() -> sum(x ./ y), params([x,y]))
-
-      g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
-    end
-  end
-
-  @testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3))
-    o = ones(s)
-    z = zeros(s)
-    Z = Zeros()
-
-
-    @testset "Explicit" begin
-      gfun(args...) = gradient((x, y) -> sum(op(x,y)), args...)
-
-      g = gfun(o, z)
-      @test gfun(o, Z) == (g[1], nothing)
-
-      g = gfun(z, o)
-      @test gfun(Z, o) == (nothing, g[2])
-    end
 
-    @testset "Implicit" begin
-      gfun(args...) = gradient(() -> sum(op(args...)), params(collect(args)))
-      g = gfun(o, z)
-      gres = gfun(o, Z)
+      gres = gfun(false, o)
      @test gres[o] == g[o]
-      @test Z ∉ gres.params
-
-      g = gfun(z, o)
-      gres = gfun(Z, o)
-      @test gres[o] == g[o]
-      @test Z ∉ gres.params
+      @test false ∉ gres.params
+      @test length(gres.params) == 1
    end
  end
 end
@@ -281,52 +232,53 @@ end
   @test stack(unstack(stacked_array, 1), 1) == stacked_array
 end
 
+
 @testset "Param remapping" begin
-  ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
-  dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))
-  dm(bias) = Chain(
-    dl(3, 5, bias),
-    dl(5, 4, bias),
-    dl(4, 3, bias)
+  count32(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
+  dl(nin, nout, bt) = Dense(count32(nout, nin), bt(nout)) # this accepts dims in same order as Dense
+  densechain(bt) = Chain(
+    dl(3, 5, bt),
+    dl(5, 4, bt),
+    dl(4, 3, bt)
  )
+  nobias(n) = false
 
-  nobias(n) = Zeros()
-  testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
-    @test l1.W == l2.W
-    @test l1.b == l2.b
-    @test typeof(l1.b) === typeof(l2.b)
+  testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, densechain(bt)))
+    @test l1.weight == l2.weight
+    @test l1.bias == l2.bias
+    @test typeof(l1.bias) === typeof(l2.bias)
  end
 
  @testset "loadparams!" begin
-    import Flux: loadparams!
    pars(w, b) = [w, b]
    import Flux: loadparams!, Zeros
    pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))]
    pars(l) = pars(l.W, l.b)
    pararray(m) = mapreduce(pars, vcat, m)
    weights(m) = mapreduce(l -> [l.W], vcat, m)
-    @testset "Bias type $bt" for bt in (Flux.zeros, nobias)
-      m = dm(bt)
+    @testset "Bias type $bt" for bt in (zeros, nobias)
+      m = densechain(bt)
      loadparams!(m, params(m))
      testdense(m, bt)
    end
-
+    #=
    @testset "$b1 to $b2" for (b1, b2, be) in (
      (Flux.zeros, ones, ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
      (ones, nobias, Flux.zeros), # Load Zeros as bias to a model with ones as bias -> model gets zeros as bias
      (nobias, ones, nobias), # Load ones as bias to a model with Zeros as bias -> model bias does not change
    )
-      m1 = dm(b1)
-      m2 = dm(b2)
+      m1 = densechain(b1)
+      m2 = densechain(b2)
      loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
      testdense(m1, be)
    end
+    =#
  end
 
  @testset "destructure" begin
    import Flux: destructure
    @testset "Bias type $bt" for bt in (zeros, nobias)
-      m = dm(bt)
+      m = densechain(bt)
      p, re = destructure(m)
      testdense(re(p), bt)
    end
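
These implicit-gradient tests hinge on params ignoring non-array leaves, which is why a false bias never enters the parameter set. A minimal check (assuming Flux of this vintage):

    using Flux

    ps = Flux.params([rand(2, 2), false])
    length(ps)  # 1: only the array is collected
    false ∉ ps  # true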

0 commit comments