@@ -69,7 +69,7 @@ extraChain(::Tuple{}, x) = ()
69
69
70
70
71
71
"""
72
- Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
72
+ Dense(in, out, σ= identity; bias= true, init= glorot_uniform)
73
73
Dense(W::AbstractMatrix, [bias, σ])
74
74
75
75
Create a traditional `Dense` layer, whose forward pass is given by:
@@ -81,7 +81,7 @@ as an `in × N` matrix, or any array with `size(x,1) == in`.
81
81
The out `y` will be a vector of length `out`, or a batch with
82
82
`size(y) == (out, size(x)[2:end]...)`
83
83
84
- Keyword `bias = false` will switch off trainable bias for the layer.
84
+ Keyword `bias= false` will switch off trainable bias for the layer.
85
85
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
86
86
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
87
87
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
@@ -109,41 +109,45 @@ julia> Flux.params(d1) # no trainable bias
109
109
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
110
110
```
111
111
"""
112
- struct Dense{F,S <: AbstractArray ,T }
113
- weight:: S
114
- bias:: T
112
+ struct Dense{F, M <: AbstractMatrix , B }
113
+ weight:: M
114
+ bias:: B
115
115
σ:: F
116
+ function Dense (W:: M , bias = true , σ:: F = identity) where {M<: AbstractMatrix , F}
117
+ b = create_bias (W, bias, size (W,1 ))
118
+ new {F,M,typeof(b)} (W, b, σ)
119
+ end
116
120
end
117
121
118
- Dense (W, b) = Dense (W, b, identity)
122
+ function Dense (in:: Integer , out:: Integer , σ = identity;
123
+ initW = nothing , initb = nothing ,
124
+ init = glorot_uniform, bias= true )
119
125
120
- Dense (W:: AbstractArray , b:: Bool = true , σ = identity) =
121
- Dense (W, create_bias (W, b, size (W,1 )), σ)
122
-
123
- function Dense (in:: Integer , out:: Integer , σ = identity; initW = nothing ,
124
- init = glorot_uniform, initb = nothing , bias:: Bool = true )
125
- if initW != = nothing
126
- Base. depwarn (" initW is deprecated, please use the `init` keyword instead" , :Dense )
127
- init = initW
126
+ W = if initW != = nothing
127
+ Base. depwarn (" keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)" , :Dense )
128
+ initW (out, in)
129
+ else
130
+ init (out, in)
128
131
end
129
132
130
- if initb != = nothing
131
- Base. depwarn (" initb is deprecated, please use the array based constructors instead " , :Dense )
132
- initb = initb
133
+ b = if bias === true && initb != = nothing
134
+ Base. depwarn (" keyword initb is deprecated, please simply supply the bias vector, bias=initb(out) " , :Dense )
135
+ initb (out)
133
136
else
134
- initb = zeros
137
+ bias
135
138
end
136
- Dense (init (out, in), bias ? initb (out) : Zeros (), σ)
139
+
140
+ return Dense (W, b, σ)
137
141
end
138
142
139
143
@functor Dense
140
144
141
145
function (a:: Dense )(x:: AbstractVecOrMat )
142
146
W, b, σ = a. weight, a. bias, a. σ
143
- σ .(W * x .+ b)
147
+ return σ .(W* x .+ b)
144
148
end
145
149
146
- (a:: Dense )(x) =
150
+ (a:: Dense )(x:: AbstractArray ) =
147
151
reshape (a (reshape (x, size (x,1 ), :)), :, size (x)[2 : end ]. .. )
148
152
149
153
function Base. show (io:: IO , l:: Dense )
@@ -292,6 +296,7 @@ If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of
292
296
with `B` a Bilinear layer.
293
297
294
298
If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
299
+
295
300
The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
296
301
which is accepted as the input to a `Chain`.
297
302
@@ -300,7 +305,6 @@ By default the bias vector is `zeros(Float32, out)`, option `bias=false` will sw
300
305
trainable bias. Either of these may be provided explicitly.
301
306
302
307
# Examples
303
-
304
308
```jldoctest
305
309
julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);
306
310
@@ -417,4 +421,4 @@ function Base.show(io::IO, m::Parallel)
417
421
print (io, " Parallel(" , m. connection, " , " )
418
422
join (io, m. layers, " , " )
419
423
print (io, " )" )
420
- end
424
+ end
0 commit comments