Skip to content

Commit d677811

Browse files
mikymatt01 and mcabbott authored
Added shape validation for Conv weight tensor (#2590)
* Added shape validation for Conv weight tensor
* Moved weight shape validation logic into a helper (`_sizecheck`) for reusability and added a test to verify shape validation
* `_sizecheck` function moved after the `_conv_size_check` definition
* Inline DimensionMismatch error to simplify weight shape check (Co-authored-by: Michael Abbott <[email protected]>)
* Remove unnecessary newline (Co-authored-by: Michael Abbott <[email protected]>)
--------- Co-authored-by: Michael Abbott <[email protected]>
1 parent fa108bb commit d677811

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/layers/conv.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,15 @@ distribution.
176176
177177
This is internally used by the [`Conv`](@ref) layer.
178178
"""
179+
179180
function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                    init = glorot_uniform, groups = 1) where N
    cin, cout = ch
    # NOTE: kept as `@assert` (not `ArgumentError`) because the existing tests
    # expect an `AssertionError` when the channel counts are not divisible by `groups`.
    @assert cin % groups == 0 "Input channel dimension must be divisible by groups."
    @assert cout % groups == 0 "Output channel dimension must be divisible by groups."
    # `_sizecheck` calls `init(filter..., cin ÷ groups, cout)` and throws a
    # `DimensionMismatch` if the returned array's size differs from that shape,
    # so a misbehaving custom `init` is caught here at construction time.
    return _sizecheck(init, filter..., cin ÷ groups, cout)
end
186189

187190
@layer Conv
@@ -501,6 +504,25 @@ function _conv_size_check(layer, x::AbstractArray)
501504
lazy" expects size(input, $d) == $n, but got ", summary(x))))
502505
end
503506
ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any)
507+
508+
"""
509+
_sizecheck(f, sz::Integer...)
510+
511+
Ensures that the output of `f(sz...)` has the expected shape `sz`.
512+
513+
Constructs a tensor using the function `f` with the given size `sz` and verifies that its shape matches `sz`.
514+
If the shape does not match, a `DimensionMismatch` error is thrown.
515+
516+
This is internally used to validate weight initialization functions.
517+
"""
518+
function _sizecheck(f, sz::Integer...)
    # Build the array by calling the initializer with the requested dimensions.
    W = f(sz...)
    # Tuples compare element-wise with `==`, so `sz` may hold any `Integer`
    # subtype (e.g. `Int32`) and still match the `Int`s returned by `size`.
    size(W) == sz || throw(DimensionMismatch(
        "Weight shape mismatch: expected $(sz), got $(size(W))",
    ))
    # Return the validated array itself so callers can use this in place of `f(sz...)`.
    return W
end
525+
504526
"""
505527
AdaptiveMaxPool(out::NTuple)
506528

test/layers/conv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ end
8787
# Test that we cannot ask for non-integer multiplication factors
8888
@test_throws AssertionError Conv((2, 2), 3=>10, groups=2)
8989
@test_throws AssertionError Conv((2, 2), 2=>9, groups=2)
90+
91+
# Test that Conv throws a DimensionMismatch error when the initializer
92+
# produces a tensor with an incorrect shape.
93+
@test_throws DimensionMismatch Conv(
94+
(3, 3),
95+
1 => 1;
96+
init = (_...) -> rand(3, 3, 1),
97+
)
9098
end
9199
end
92100

0 commit comments

Comments
 (0)