Skip to content

Commit 8a5b977

Browse files
Merge #1561
1561: Reverts #1557 r=DhairyaLGandhi a=DhairyaLGandhi Reverts #1557 Co-authored-by: Dhairya Gandhi <[email protected]>
2 parents 4734e72 + d98be54 commit 8a5b977

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

src/layers/basic.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ extraChain(::Tuple{}, x) = ()
6969

7070

7171
"""
72-
Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
72+
Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
7373
Dense(W::AbstractMatrix, [bias, σ])
7474
7575
Create a traditional `Dense` layer, whose forward pass is given by:
@@ -81,7 +81,7 @@ as an `in × N` matrix, or any array with `size(x,1) == in`.
8181
The output `y` will be a vector of length `out`, or a batch with
8282
`size(y) == (out, size(x)[2:end]...)`
8383
84-
Keyword `bias = false` will switch off trainable bias for the layer.
84+
Keyword `bias=false` will switch off trainable bias for the layer.
8585
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
8686
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
8787
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
@@ -109,41 +109,45 @@ julia> Flux.params(d1) # no trainable bias
109109
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
110110
```
111111
"""
112-
struct Dense{F,S<:AbstractArray,T}
113-
weight::S
114-
bias::T
112+
struct Dense{F, M<:AbstractMatrix, B}
113+
weight::M
114+
bias::B
115115
σ::F
116+
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
117+
b = create_bias(W, bias, size(W,1))
118+
new{F,M,typeof(b)}(W, b, σ)
119+
end
116120
end
117121

118-
Dense(W, b) = Dense(W, b, identity)
122+
function Dense(in::Integer, out::Integer, σ = identity;
123+
initW = nothing, initb = nothing,
124+
init = glorot_uniform, bias=true)
119125

120-
Dense(W::AbstractArray, b::Bool = true, σ = identity) =
121-
Dense(W, create_bias(W, b, size(W,1)), σ)
122-
123-
function Dense(in::Integer, out::Integer, σ = identity; initW = nothing,
124-
init = glorot_uniform, initb = nothing, bias::Bool = true)
125-
if initW !== nothing
126-
Base.depwarn("initW is deprecated, please use the `init` keyword instead", :Dense)
127-
init = initW
126+
W = if initW !== nothing
127+
Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
128+
initW(out, in)
129+
else
130+
init(out, in)
128131
end
129132

130-
if initb !== nothing
131-
Base.depwarn("initb is deprecated, please use the array based constructors instead", :Dense)
132-
initb = initb
133+
b = if bias === true && initb !== nothing
134+
Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
135+
initb(out)
133136
else
134-
initb = zeros
137+
bias
135138
end
136-
Dense(init(out, in), bias ? initb(out) : Zeros(), σ)
139+
140+
return Dense(W, b, σ)
137141
end
138142

139143
@functor Dense
140144

141145
function (a::Dense)(x::AbstractVecOrMat)
142146
W, b, σ = a.weight, a.bias, a.σ
143-
σ.(W * x .+ b)
147+
return σ.(W*x .+ b)
144148
end
145149

146-
(a::Dense)(x) =
150+
(a::Dense)(x::AbstractArray) =
147151
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
148152

149153
function Base.show(io::IO, l::Dense)
@@ -292,6 +296,7 @@ If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of
292296
with `B` a Bilinear layer.
293297
294298
If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
299+
295300
The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
296301
which is accepted as the input to a `Chain`.
297302
@@ -300,7 +305,6 @@ By default the bias vector is `zeros(Float32, out)`, option `bias=false` will sw
300305
trainable bias. Either of these may be provided explicitly.
301306
302307
# Examples
303-
304308
```jldoctest
305309
julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);
306310
@@ -417,4 +421,4 @@ function Base.show(io::IO, m::Parallel)
417421
print(io, "Parallel(", m.connection, ", ")
418422
join(io, m.layers, ", ")
419423
print(io, ")")
420-
end
424+
end

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,6 @@ to the constructor's keyword `bias=bias`.
388388
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
389389
bias ? fill!(similar(weights, dims...), 0) : Zeros()
390390
end
391-
392391
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
393392
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
394393
bias

test/layers/basic.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,16 @@ import Flux: activations
4040
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
4141
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
4242

43-
@test_skip Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
43+
@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
4444
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
4545

46-
4746
@test_throws MethodError Dense(10, 10.5)
4847
@test_throws MethodError Dense(10, 10.5, tanh)
49-
# @test_throws DimensionMismatch Dense(3,4; bias=rand(5))
50-
# @test_throws DimensionMismatch Dense(rand(4,3), rand(5))
51-
# @test_throws MethodError Dense(rand(5))
52-
# @test_throws MethodError Dense(rand(5), rand(5))
53-
# @test_throws MethodError Dense(rand(5), rand(5), tanh)
48+
@test_throws DimensionMismatch Dense(3,4; bias=rand(5))
49+
@test_throws DimensionMismatch Dense(rand(4,3), rand(5))
50+
@test_throws MethodError Dense(rand(5))
51+
@test_throws MethodError Dense(rand(5), rand(5))
52+
@test_throws MethodError Dense(rand(5), rand(5), tanh)
5453
end
5554
@testset "dimensions" begin
5655
@test length(Dense(10, 5)(randn(10))) == 5

0 commit comments

Comments
 (0)