
Commit ddc688f

revert
1 parent 4734e72 commit ddc688f

5 files changed: +40 −88 lines


src/layers/basic.jl

Lines changed: 26 additions & 77 deletions
@@ -1,23 +1,16 @@
 """
     Chain(layers...)
-
 Chain multiple layers / functions together, so that they are called in sequence
 on a given input.
-
 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
-
 # Examples
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);
-
 julia> m(5) == 26
 true
-
 julia> m = Chain(Dense(10, 5), Dense(5, 2));
-
 julia> x = rand(10);
-
 julia> m(x) == m[2](m[1](x))
 true
 ```
@@ -52,7 +45,6 @@ end
 # only slightly changed to better handle interaction with Zygote @dsweber2
 """
     activations(c::Chain, input)
-
 Calculate the forward results of each layers in Chain `c` with `input` as model input.
 """
 function activations(c::Chain, input)
@@ -69,81 +61,75 @@ extraChain(::Tuple{}, x) = ()
 
 
 """
-    Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
+    Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
     Dense(W::AbstractMatrix, [bias, σ])
-
 Create a traditional `Dense` layer, whose forward pass is given by:
-
     y = σ.(W * x .+ bias)
-
 The input `x` should be a vector of length `in`, or batch of vectors represented
 as an `in × N` matrix, or any array with `size(x,1) == in`.
 The out `y` will be a vector of length `out`, or a batch with
 `size(y) == (out, size(x)[2:end]...)`
-
-
-Keyword `bias = false` will switch off trainable bias for the layer.
+Keyword `bias=false` will switch off trainable bias for the layer.
 The initialisation of the weight matrix is `W = init(out, in)`, calling the function
 given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
 The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
-
 # Examples
 ```jldoctest
 julia> d = Dense(5, 2)
 Dense(5, 2)
-
 julia> d(rand(Float32, 5, 64)) |> size
 (2, 64)
-
 julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions
 (2, 1, 1, 64)
-
 julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
 Dense(5, 2, tanh; bias=false)
-
 julia> d1(ones(5))
-2-element Vector{Float64}:
+2-element Array{Float64,1}:
  0.9999092042625951
  0.9999092042625951
-
 julia> Flux.params(d1) # no trainable bias
 Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
 ```
 """
-struct Dense{F,S<:AbstractArray,T}
-  weight::S
-  bias::T
+struct Dense{F, M<:AbstractMatrix, B}
+  weight::M
+  bias::B
   σ::F
+  function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
+    b = create_bias(W, bias, size(W,1))
+    new{F,M,typeof(b)}(W, b, σ)
+  end
 end
 
-Dense(W, b) = Dense(W, b, identity)
-
-Dense(W::AbstractArray, b::Bool = true, σ = identity) =
-  Dense(W, create_bias(W, b, size(W,1)), σ)
+function Dense(in::Integer, out::Integer, σ = identity;
+               initW = nothing, initb = nothing,
+               init = glorot_uniform, bias=true)
 
-function Dense(in::Integer, out::Integer, σ = identity; initW = nothing,
-               init = glorot_uniform, initb = nothing, bias::Bool = true)
-  if initW !== nothing
-    Base.depwarn("initW is deprecated, please use the `init` keyword instead", :Dense)
-    init = initW
+  W = if initW !== nothing
+    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense)
+    initW(out, in)
+  else
+    init(out, in)
   end
 
-  if initb !== nothing
-    Base.depwarn("initb is deprecated, please use the array based constructors instead", :Dense)
-    initb = initb
+  b = if bias === true && initb !== nothing
+    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
+    initb(out)
   else
-    initb = zeros
+    bias
   end
-  Dense(init(out, in), bias ? initb(out) : Zeros(), σ)
+
+  return Dense(W, b, σ)
 end
 
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
   W, b, σ = a.weight, a.bias, a.σ
-  σ.(W * x .+ b)
+  return σ.(W*x .+ b)
 end
 
-(a::Dense)(x) =
+(a::Dense)(x::AbstractArray) =
   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, l::Dense)
@@ -156,14 +142,10 @@ end
 """
     Diagonal(α, β)
     Diagonal(size::Integer...)
-
 Create an element-wise linear layer, which performs
-
     y = α .* x .+ β
-
 The learnable arrays are initialised `α = ones(Float32, size)` and
 `β = zeros(Float32, size)`.
-
 Used by [`LayerNorm`](@ref).
 """
 struct Diagonal{T}
@@ -197,11 +179,9 @@ end
 
 """
     Maxout(over)
-
 The [Maxout](https://arxiv.org/abs/1302.4389) layer has a number of
 internal layers which all receive the same input. It returns the elementwise
 maximum of the internal layers' outputs.
-
 Maxout over linear dense layers satisfies the univeral approximation theorem.
 """
 struct Maxout{FS<:Tuple}
@@ -210,20 +190,15 @@ end
 
 """
     Maxout(f, n_alts)
-
 Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
 The function takes no arguments and should return some callable layer.
 Conventionally, this is a linear dense layer.
-
 # Examples
-
 This constructs a `Maxout` layer over 4 internal dense linear layers, each
 identical in structure (784 inputs, 128 outputs):
 ```jldoctest
 julia> insize = 784;
-
 julia> outsize = 128;
-
 julia> Maxout(()->Dense(insize, outsize), 4);
 ```
 """
@@ -240,25 +215,19 @@ end
 
 """
     SkipConnection(layer, connection)
-
 Create a skip connection which consists of a layer or `Chain` of consecutive
 layers and a shortcut connection linking the block's input to the output
 through a user-supplied 2-argument callable. The first argument to the callable
 will be propagated through the given `layer` while the second is the unchanged,
 "skipped" input.
-
 The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`.
 Here is a more complicated example:
 ```jldoctest
 julia> m = Conv((3,3), 4 => 7, pad=(1,1));
-
 julia> x = ones(Float32, 5, 5, 4, 10);
-
 julia> size(m(x)) == (5, 5, 7, 10)
 true
-
 julia> sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3));
-
 julia> size(sm(x)) == (5, 5, 11, 10)
 true
 ```
@@ -281,45 +250,32 @@ end
 """
     Bilinear(in1, in2, out, σ=identity; bias=true, init=glorot_uniform)
     Bilinear(W::AbstractArray, [bias, σ])
-
 Creates a Bilinear layer, which operates on two inputs at the same time.
 Its output, given vectors `x` & `y`, is another vector `z` with,
 for all `i ∈ 1:out`:
-
     z[i] = σ(x' * W[i,:,:] * y + bias[i])
-
 If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
 with `B` a Bilinear layer.
-
 If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
 The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
 which is accepted as the input to a `Chain`.
-
 The initialisation works as for [`Dense`](@ref) layer, with `W = init(out, in1, in2)`.
 By default the bias vector is `zeros(Float32, out)`, option `bias=false` will switch off
 trainable bias. Either of these may be provided explicitly.
-
 # Examples
-
 ```jldoctest
 julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);
-
 julia> B = Flux.Bilinear(5, 5, 7);
-
 julia> B(x) |> size # interactions based on one input
 (7, 32)
-
 julia> B(x,y) == B((x,y)) # two inputs, may be given as a tuple
 true
-
 julia> sc = SkipConnection(
          Chain(Dense(5, 20, tanh), Dense(20, 9, tanh)),
          Flux.Bilinear(9, 5, 3, bias=false),
        ); # used as the recombinator, with skip as the second input
-
 julia> sc(x) |> size
 (3, 32)
-
 julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the output
 Bilinear(8, 16, 4, tanh, bias=false)
 ```
@@ -373,26 +329,19 @@ end
 
 """
     Parallel(connection, layers...)
-
 Create a 'Parallel' layer that passes an input array to each path in
 `layers`, reducing the output with `connection`.
-
 Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
 If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
-
 # Examples
-
 ```jldoctest
 julia> model = Chain(Dense(3, 5),
                      Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
                      Dense(8, 17));
-
 julia> size(model(rand(3)))
 (17,)
-
 julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
 Parallel(+, Dense(10, 2), Dense(5, 2))
-
 julia> size(model(rand(10), rand(5)))
 (2,)
 ```
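For context, a rough sketch of how the `Dense` constructors restored above behave, in the same REPL style as the docstrings (assuming this branch of Flux is loaded; exact printing may differ):

julia> using Flux

julia> d1 = Dense(ones(2, 5), false, tanh)  # explicit weight matrix, trainable bias switched off
Dense(5, 2, tanh; bias=false)

julia> d1.bias isa Flux.Zeros               # the absent bias is stored as Zeros()
true

julia> d2 = Dense(3, 4; init=Base.randn);   # weights drawn from the function given to `init`

julia> eltype(d2.weight), size(d2.bias)
(Float64, (4,))

The deprecated `initW`/`initb` keywords still work but go through the `Base.depwarn` branches of the restored keyword constructor.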

src/utils.jl

Lines changed: 6 additions & 2 deletions
@@ -388,10 +388,14 @@ to the constructor's keyword `bias=bias`.
 function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
   bias ? fill!(similar(weights, dims...), 0) : Zeros()
 end
-
 function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
-  bias
+  if eltype(bias) == eltype(weights)
+    return bias
+  else
+    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
+    return broadcast(eltype(weights), bias)
+  end
 end
 
 """

test/layers/basic.jl

Lines changed: 6 additions & 7 deletions
@@ -40,17 +40,16 @@ import Flux: activations
       @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
       @test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
 
-      @test_skip Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
+      @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
       @test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
 
-
       @test_throws MethodError Dense(10, 10.5)
       @test_throws MethodError Dense(10, 10.5, tanh)
-      # @test_throws DimensionMismatch Dense(3,4; bias=rand(5))
-      # @test_throws DimensionMismatch Dense(rand(4,3), rand(5))
-      # @test_throws MethodError Dense(rand(5))
-      # @test_throws MethodError Dense(rand(5), rand(5))
-      # @test_throws MethodError Dense(rand(5), rand(5), tanh)
+      @test_throws DimensionMismatch Dense(3,4; bias=rand(5))
+      @test_throws DimensionMismatch Dense(rand(4,3), rand(5))
+      @test_throws MethodError Dense(rand(5))
+      @test_throws MethodError Dense(rand(5), rand(5))
+      @test_throws MethodError Dense(rand(5), rand(5), tanh)
     end
     @testset "dimensions" begin
       @test length(Dense(10, 5)(randn(10))) == 5

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ end
   @test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
   if fun == Conv
     @test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
-    @test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
+    @test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
   elseif fun == DepthwiseConv
     @test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64}
   end
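The un-skipped assertion relies on the same conversion path added in src/utils.jl: a bias supplied as an integer range is expected to come out as a plain vector matching the weight eltype. A sketch of what the restored test asserts (not a verified transcript; `Conv` is given the weight array directly here, as in the neighbouring test lines):

julia> using Flux

julia> c = Conv(rand(2,3,4,5,6), 1:6);  # Float64 weights, bias given as a UnitRange

julia> c.bias isa Vector{Float64}
true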

test/utils.jl

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ end
   testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
     @test l1.W == l2.W
     @test l1.b == l2.b
-    @test_skip typeof(l1.b) === typeof(l2.b)
+    @test typeof(l1.b) === typeof(l2.b)
   end
 
   @testset "loadparams!" begin
