
Commit 2d357ad

Add friendly size check (#2176)
* add _size_check
* fix _channels_in(::CrossCor)
* outputsize
* fix GroupNorm not to re-use the same name for different things, dammit
* friendly error for ndims too
* is LayerNorm(1) allowed?
* rm outputsize(::Chain)
* doctest
1 parent d511d7a commit 2d357ad

File tree: 7 files changed, +45 −40 lines
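The commit replaces bare `@assert`s and low-level NNlib errors with a shared `_size_check` helper, so a wrongly sized input now produces an error that names the layer and summarises the input. As a rough illustration (not part of the diff; the exact rendering of `DimensionMismatch` varies by Julia version), passing a 4-element vector to `Dense(2 => 3)` now fails like this:

```julia
julia> using Flux

julia> Dense(2 => 3)(rand(Float32, 4))
ERROR: DimensionMismatch: layer Dense(2 => 3) expects size(input, 1) == 2, but got 4-element Vector{Float32}
```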

src/layers/basic.jl

Lines changed: 12 additions & 1 deletion
@@ -168,13 +168,16 @@ end
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
+  _size_check(a, x, 1 => size(a.weight, 2))
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
   xT = _match_eltype(a, x)  # fixes Float64 input, etc.
   return σ.(a.weight * xT .+ a.bias)
 end
 
-(a::Dense)(x::AbstractArray) =
+function (a::Dense)(x::AbstractArray)
+  _size_check(a, x, 1 => size(a.weight, 2))
   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
+end
 
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1))
@@ -186,6 +189,14 @@ end
 Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) =
   Scale(W.diag, bias, σ)
 
+function _size_check(layer, x::AbstractArray, (d, n)::Pair)
+  d > 0 || throw(DimensionMismatch(string("layer ", layer,
+    " expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x))))
+  size(x, d) == n || throw(DimensionMismatch(string("layer ", layer,
+    " expects size(input, $d) == $n, but got ", summary(x))))
+end
+ChainRulesCore.@non_differentiable _size_check(::Any...)
+
 """
     Scale(size::Integer..., σ=identity; bias=true, init=ones32)
     Scale(scale::AbstractArray, [bias, σ])
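The helper takes a `dim => expected_length` pair: the first branch fires when `d <= 0`, i.e. the input has too few dimensions for the layer to index its target axis at all, and the second reports the mismatched axis via `summary(x)`. A minimal sketch of calling it directly (it is an internal helper, not public API; the quoted error text is illustrative):

```julia
using Flux

d = Dense(2 => 3)   # weight is 3×2, so inputs need size(x, 1) == 2
Flux._size_check(d, rand(Float32, 2, 5), 1 => size(d.weight, 2))   # no error for a matching first axis

# A mismatched first axis throws, naming the layer and summarising the input:
# Flux._size_check(d, rand(Float32, 4, 5), 1 => size(d.weight, 2))
# ERROR: DimensionMismatch: layer Dense(2 => 3) expects size(input, 1) == 2, but got 4×5 Matrix{Float32}
```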

src/layers/conv.jl

Lines changed: 5 additions & 0 deletions
@@ -195,6 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) =
 ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
 
 function (c::Conv)(x::AbstractArray)
+  _size_check(c, x, ndims(x)-1 => _channels_in(c))
   σ = NNlib.fast_act(c.σ, x)
   cdims = conv_dims(c, x)
   xT = _match_eltype(c, x)
@@ -329,6 +330,7 @@ end
 ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 
 function (c::ConvTranspose)(x::AbstractArray)
+  _size_check(c, x, ndims(x)-1 => _channels_in(c))
   σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
   xT = _match_eltype(c, x)
@@ -418,6 +420,8 @@ struct CrossCor{N,M,F,A,V}
   dilation::NTuple{N,Int}
 end
 
+_channels_in(l::CrossCor) = size(l.weight, ndims(l.weight)-1)
+
 """
     CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation])
 
@@ -468,6 +472,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
 ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
 
 function (c::CrossCor)(x::AbstractArray)
+  _size_check(c, x, ndims(x)-1 => _channels_in(c))
   σ = NNlib.fast_act(c.σ, x)
   cdims = crosscor_dims(c, x)
   xT = _match_eltype(c, x)
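All three convolution layers now check the channel axis, `ndims(x)-1`, against `_channels_in` before calling into NNlib, and the new `_channels_in(::CrossCor)` method reads the channel count from the next-to-last axis of the weight so `CrossCor` gets the same check. A hedged sketch of the resulting behaviour (error text follows `_size_check` above; exact printing may differ by Julia version):

```julia
julia> using Flux

julia> layer = Conv((3, 3), 3 => 16);

julia> layer(rand(Float32, 10, 10, 7, 64))   # 7 channels where 3 are expected
ERROR: DimensionMismatch: layer Conv((3, 3), 3 => 16) expects size(input, 3) == 3, but got 10×10×7×64 Array{Float32, 4}
```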

src/layers/normalise.jl

Lines changed: 19 additions & 17 deletions
@@ -188,7 +188,14 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]
 
 @functor LayerNorm
 
-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+function (a::LayerNorm)(x::AbstractArray)
+  ChainRulesCore.@ignore_derivatives if a.diag isa Scale
+    for d in 1:ndims(a.diag.scale)
+      _size_check(a, x, d => size(a.diag.scale, d))
+    end
+  end
+  a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+end
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))
@@ -318,9 +325,8 @@ end
 @functor BatchNorm
 trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
-function (BN::BatchNorm)(x)
-  @assert size(x, ndims(x)-1) == BN.chs
-  N = ndims(x)
+function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
+  _size_check(BN, x, N-1 => BN.chs)
   reduce_dims = [1:N-2; N]
   affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
   return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
@@ -408,10 +414,8 @@ end
 @functor InstanceNorm
 trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
-function (l::InstanceNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == l.chs
-  N = ndims(x)
+function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
+  _size_check(l, x, N-1 => l.chs)
   reduce_dims = 1:N-2
   affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
   return _norm_layer_forward(l, x; reduce_dims, affine_shape)
@@ -511,17 +515,15 @@ end
     nothing, chs)
 end
 
-function (gn::GroupNorm)(x)
-  @assert ndims(x) > 2
-  @assert size(x, ndims(x)-1) == gn.chs
-  N = ndims(x)
+function (gn::GroupNorm)(x::AbstractArray)
+  _size_check(gn, x, ndims(x)-1 => gn.chs)
   sz = size(x)
-  x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N])
-  N = ndims(x)
+  x2 = reshape(x, sz[1:end-2]..., sz[end-1]÷gn.G, gn.G, sz[end])
+  N = ndims(x2)  # == ndims(x)+1
   reduce_dims = 1:N-2
-  affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N)
-  x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
-  return reshape(x, sz)
+  affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x2, i) : 1, N)
+  x3 = _norm_layer_forward(gn, x2; reduce_dims, affine_shape)
+  return reshape(x3, sz)
 end
 
 testmode!(m::GroupNorm, mode = true) =
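For `LayerNorm`, each leading axis of the input is checked against the corresponding axis of the learned `Scale`, while `BatchNorm`, `InstanceNorm` and `GroupNorm` check only the channel axis; the renaming to `x2`/`x3` in `GroupNorm` stops reusing one name for different arrays without changing the computation. A rough illustration of the `LayerNorm` check (error rendering varies by Julia version):

```julia
julia> using Flux

julia> ln = LayerNorm(5);   # diag is Scale(5), so size(input, 1) must be 5

julia> ln(rand(Float32, 5, 16)) |> size
(5, 16)

julia> ln(rand(Float32, 3, 16))
ERROR: DimensionMismatch: layer LayerNorm(5) expects size(input, 1) == 5, but got 3×16 Matrix{Float32}
```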

src/layers/recurrent.jl

Lines changed: 4 additions & 0 deletions
@@ -202,6 +202,7 @@ RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_
 
 function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
   Wi, Wh, b = m.Wi, m.Wh, m.b
+  _size_check(m, x, 1 => size(Wi,2))
   σ = NNlib.fast_act(m.σ, x)
   xT = _match_eltype(m, T, x)
   h = σ.(Wi*xT .+ Wh*h .+ b)
@@ -307,6 +308,7 @@ function LSTMCell((in, out)::Pair;
 end
 
 function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
+  _size_check(m, x, 1 => size(m.Wi,2))
   b, o = m.b, size(h, 1)
   xT = _match_eltype(m, T, x)
   g = muladd(m.Wi, xT, muladd(m.Wh, h, b))
@@ -379,6 +381,7 @@ GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = ze
   GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
 function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
+  _size_check(m, x, 1 => size(m.Wi,2))
   Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
   xT = _match_eltype(m, T, x)
   gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
@@ -448,6 +451,7 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
     init(out, out), init_state(out,1))
 
 function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
+  _size_check(m, x, 1 => size(m.Wi,2))
   Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
   xT = _match_eltype(m, T, x)
   gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
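Each recurrent cell now checks the feature axis of `x` against `size(m.Wi, 2)`, i.e. the `in` half of `in => out`, before the matrix multiply. A hedged example through the `LSTM` wrapper (output shown is illustrative; error rendering varies by Julia version):

```julia
julia> using Flux

julia> m = LSTM(3 => 5);

julia> m(rand(Float32, 3, 7)) |> size   # 3 features, batch of 7
(5, 7)

julia> m(rand(Float32, 4, 7))           # wrong feature count is caught in the cell
ERROR: DimensionMismatch: layer LSTMCell(3 => 5) expects size(input, 1) == 3, but got 4×7 Matrix{Float32}
```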

src/outputsize.jl

Lines changed: 4 additions & 20 deletions
@@ -62,7 +62,7 @@ which should work out of the box for custom layers.
 If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.
 
 # Examples
-```julia-repl
+```jldoctest
 julia> using Flux: outputsize
 
 julia> outputsize(Dense(10 => 4), (10,); padbatch=true)
@@ -80,9 +80,7 @@ julia> outputsize(m, (10, 10, 3, 64))
 (6, 6, 32, 64)
 
 julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
-┌ Error: layer Conv((3, 3), 3=>16), index 1 in Chain, gave an error with input of size (10, 10, 7, 64)
-└ @ Flux ~/.julia/dev/Flux/src/outputsize.jl:114
-DimensionMismatch("Input channels must match! (7 vs. 3)")
+DimensionMismatch("layer Conv((3, 3), 3 => 16) expects size(input, 3) == 3, but got 10×10×7×64 Array{Flux.NilNumber.Nil, 4}")
 
 julia> outputsize([Dense(10 => 4), Dense(4 => 2)], (10, 1))  # Vector of layers becomes a Chain
 (2, 1)
@@ -97,19 +95,6 @@ nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s...,1)) : fi
 nil_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = nil_input.(pad, multi)
 nil_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = nil_input(pad, tup...)
 
-function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch=false)
-  x = nil_input(padbatch, inputsizes...)
-  for (i,lay) in enumerate(m.layers)
-    try
-      x = lay(x)
-    catch err
-      str = x isa AbstractArray ? "with input of size $(size(x))" : ""
-      @error "layer $lay, index $i in Chain, gave an error $str"
-      rethrow(err)
-    end
-  end
-  return size(x)
-end
 
 """
     outputsize(m, x_size, y_size, ...; padbatch=false)
@@ -148,9 +133,8 @@ outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chai
 ## bypass statistics in normalization layers
 
 for layer in (:BatchNorm, :InstanceNorm, :GroupNorm)  # LayerNorm works fine
-  @eval function (l::$layer)(x::AbstractArray{Nil})
-    l.chs == size(x, ndims(x)-1) || throw(DimensionMismatch(
-      string($layer, " expected ", l.chs, " channels, but got size(x) == ", size(x))))
+  @eval function (l::$layer)(x::AbstractArray{Nil,N}) where N
+    _size_check(l, x, N-1 => l.chs)
     x
   end
 end
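With the per-layer checks in place, the special `outputsize(::Chain, ...)` method that wrapped each layer call in `try`/`catch` and logged an `@error` is no longer needed: the plain `Nil`-array forward pass throws the friendly `DimensionMismatch` directly, and the norm-layer bypass reuses `_size_check` too. A rough illustration using that bypass (exact error rendering may differ):

```julia
julia> using Flux

julia> using Flux: outputsize

julia> outputsize(BatchNorm(3), (32, 32, 3, 16))   # Nil bypass, statistics never run
(32, 32, 3, 16)

julia> outputsize(BatchNorm(3), (32, 32, 5, 16))
ERROR: DimensionMismatch: layer BatchNorm(3) expects size(input, 3) == 3, but got 32×32×5×16 Array{Flux.NilNumber.Nil, 4}
```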

test/cuda/layers.jl

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ dropout_layers = [Dropout, AlphaDropout]
 gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu = false) # dropout is not deterministic
 
 layer_norm = [LayerNorm]
-gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 1, test_cpu = false) #TODO fix errors
+gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 28, test_cpu = false) #TODO fix errors
gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 5,4), 5)
 
 upsample = [x -> Upsample(scale=x)]
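The GPU test previously passed `1` as the `LayerNorm` size over a `28×28×3×4` input; under the new per-axis check a `Scale` of length 1 demands `size(x, 1) == 1`, so the test now uses `28`. Roughly (illustrative, CPU shown; error rendering may differ by Julia version):

```julia
julia> using Flux

julia> x = rand(Float32, 28, 28, 3, 4);

julia> LayerNorm(28)(x) |> size
(28, 28, 3, 4)

julia> LayerNorm(1)(x)
ERROR: DimensionMismatch: layer LayerNorm(1) expects size(input, 1) == 1, but got 28×28×3×4 Array{Float32, 4}
```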

test/outputsize.jl

Lines changed: 0 additions & 1 deletion
@@ -144,7 +144,6 @@ end
   @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
   m2 = LayerNorm(3, 2)
   @test outputsize(m2, (3, 2)) == (3, 2) == size(m2(randn(3, 2)))
-  @test outputsize(m2, (3,)) == (3, 2) == size(m2(randn(3, 2)))
 
   m = BatchNorm(3)
   @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
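The deleted line exercised `outputsize` on a vector input to `LayerNorm(3, 2)`, relying on broadcasting to supply the missing second axis; the per-axis check now rejects that shape, which is presumably why the test was dropped. A rough sketch (exact error rendering may differ):

```julia
julia> using Flux

julia> using Flux: outputsize

julia> m2 = LayerNorm(3, 2);

julia> outputsize(m2, (3, 2))
(3, 2)

julia> outputsize(m2, (3,))
ERROR: DimensionMismatch: layer LayerNorm(3, 2) expects size(input, 2) == 2, but got 3-element Vector{Flux.NilNumber.Nil}
```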
