
Commit bc7f7c7

Author: Michael Abbott
make Conv constructors more forgiving about bias type
1 parent: 3c4875e
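
This relaxes the second positional argument of `Conv`, `ConvTranspose`, `DepthwiseConv`, and `CrossCor` from `Union{Bool, Zeros, AbstractVector{T}}` to an untyped `bias = true`, leaving `create_bias` to normalise whatever is passed. A minimal usage sketch of the resulting behaviour, assuming Flux at this commit (the outcomes follow the new tests in the diff below):

using Flux

w = rand(2, 3, 4, 5, 6)   # 5-dim weight array; the last dim gives 6 output channels

Conv(w).bias              # bias defaults to `true`: a trainable zero Vector{Float64}
Conv(w, false).bias       # `false` means no bias, stored as Flux.Zeros
Conv(w, rand(6)).bias     # an explicit length-6 bias vector is accepted
Conv(w, 1:6).bias         # so is any other AbstractVector, e.g. a range
Conv(w, true, relu)       # the activation still follows the bias argument
# Conv(w, rand(5))        # a wrong-length bias throws DimensionMismatch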

File tree: 2 files changed (+31 -19 lines)


src/layers/conv.jl

Lines changed: 19 additions & 19 deletions
@@ -95,8 +95,8 @@ struct Conv{N,M,F,A,V}
 end
 
 """
-    Conv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+    Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+
 Constructs a convolutional layer with the given weight and bias.
 Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
 method.
@@ -117,13 +117,13 @@ julia> params(c1) |> length
 2
 ```
 """
-function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function Conv(w::AbstractArray{T,N}, bias = true, σ = identity;
               stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(w, b, size(w, N))
-  return Conv(σ, w, bias, stride, pad, dilation)
+  b = create_bias(w, bias, size(w, N))
+  return Conv(σ, w, b, stride, pad, dilation)
 end
 
 function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
@@ -206,18 +206,18 @@ struct ConvTranspose{N,M,F,A,V}
 end
 
 """
-    ConvTranspose(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
 """
-function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(w, b, size(w, N-1))
-  return ConvTranspose(σ, w, bias, stride, pad, dilation)
+  b = create_bias(w, bias, size(w, N-1))
+  return ConvTranspose(σ, w, b, stride, pad, dilation)
 end
 
 function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
@@ -304,17 +304,17 @@ end
 
 """
     DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
 """
-function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(w, b, prod(size(w)[N-1:end]))
-  return DepthwiseConv(σ, w, bias, stride, pad, dilation)
+  b = create_bias(w, bias, prod(size(w)[N-1:end]))
+  return DepthwiseConv(σ, w, b, stride, pad, dilation)
 end
 
 function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
@@ -391,18 +391,18 @@ struct CrossCor{N,M,F,A,V}
 end
 
 """
-    CrossCor(weight::AbstractArray, bias, [activation; stride, pad, dilation])
-
+    CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+
 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
 """
-function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
                   stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(w, b, size(w, N))
-  return CrossCor(σ, w, bias, stride, pad, dilation)
+  b = create_bias(w, bias, size(w, N))
+  return CrossCor(σ, w, b, stride, pad, dilation)
 end
 
 function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;

test/layers/conv.jl

Lines changed: 12 additions & 0 deletions
@@ -188,3 +188,15 @@ end
   # https://github.com/FluxML/Flux.jl/issues/1421
   @test Conv((5, 5), 10 => 20, identity; init = Base.randn).bias isa Vector{Float64}
 end
+
+@testset "constructors: $fun" for fun in [Conv, CrossCor, ConvTranspose, DepthwiseConv]
+  @test fun(rand(2,3,4)).bias isa Vector{Float64}
+  @test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
+  if fun == Conv
+    @test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
+    @test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
+  elseif fun == DepthwiseConv
+    @test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64}
+  end
+  @test_throws DimensionMismatch fun(rand(2,3,4), rand(6))
+end
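
These tests pin down the contract of the `create_bias` helper that the looser constructors rely on: `true` yields a trainable zero vector in the weight's element type, `false` yields the non-trainable `Flux.Zeros` placeholder, any `AbstractVector` of the right length is accepted, and a mismatched length raises `DimensionMismatch`. A hypothetical sketch of that contract, inferred from the tests rather than taken from the Flux source:

# Hypothetical reconstruction of the create_bias contract implied by the
# tests above; the real helper lives elsewhere in Flux and may differ.
function create_bias(weights::AbstractArray, bias::Bool, len::Integer)
  # `true` → trainable zeros matching the weight's element type; `false` → no bias
  bias ? fill!(similar(weights, len), 0) : Flux.Zeros()
end

function create_bias(weights::AbstractArray, bias::AbstractVector, len::Integer)
  length(bias) == len ||
    throw(DimensionMismatch("expected a bias of length $len, got $(length(bias))"))
  collect(eltype(weights), bias)  # e.g. 1:6 becomes a Vector{Float64}
end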
