Skip to content

Commit 7ee354d

Browse files
authored
Fix output shape of ∇conv_data_direct! (#171)
* Fix output shape of ∇conv_data_direct! * Add test to check output shapes of direct conv
1 parent b47a2ef commit 7ee354d

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/impl/conv_direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
161161
dilation=dilation(cdims),
162162
flipkernel=flipkernel(cdims))
163163
dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta)
164-
return transpose_swapbatch(dx)
164+
return dx
165165
end
166166

167167
"""

test/conv.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,3 +670,17 @@ end
670670
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (10, 5, 9, 10)
671671
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (10, 5, 9, 10)
672672
end
673+
674+
# https://github.com/FluxML/NNlib.jl/pull/171
675+
@testset "conv_direct! - Check Sizes" begin
676+
x_size = (6, 7, 8, 5, 3)
677+
y_size = (5, 6, 7, 4, 3)
678+
w_size = (2, 2, 2, 5, 4)
679+
x = randn(Float32, x_size);
680+
y = randn(Float32, y_size);
681+
w = randn(Float32, w_size);
682+
cdims = DenseConvDims(x_size, w_size)
683+
@test size(NNlib.conv_direct!(y, x, w, cdims)) == y_size
684+
@test size(NNlib.∇conv_data_direct!(x, y, w, cdims)) == x_size
685+
@test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size
686+
end

0 commit comments

Comments
 (0)