Skip to content

Commit 1a683bc

Browse files
committed
Added tests for wrappers
1 parent df48e61 commit 1a683bc

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

test/conv.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
1212
elseif T == DepthwiseConvDims
1313
w = randn(1,2,4,3)
1414
end
15-
15+
1616
# First, getters:
1717
cdims = T(x, w)
1818
@test input_size(cdims) == size(x)[1:2]
@@ -39,7 +39,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
3939
@test padding(cdims) == (3,3,3,3)
4040
@test flipkernel(cdims) == true
4141
@test output_size(cdims) == (6,4)
42-
42+
4343
# Next, tuple settings
4444
cdims = T(x, w; stride=(1, 2), dilation=(1, 2), padding=(0,1))
4545
@test stride(cdims) == (1,2)
@@ -215,7 +215,7 @@ conv_answer_dict = Dict(
215215
"dx_dil" => reshape([
216216
4864, 5152, 9696, 4508, 4760, 6304, 6592, 12396, 5768, 6020, 3648,
217217
3864, 7120, 3220, 3400, 4728, 4944, 9100, 4120, 4300, 0, 0, 0, 0, 0, 0,
218-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040,
218+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040,
219219
3152, 3296, 5804, 2472, 2580, 1216, 1288, 1968, 644, 680, 1576, 1648,
220220
2508, 824, 860.
221221
], (5,4,3)),
@@ -273,7 +273,7 @@ conv_answer_dict = Dict(
273273

274274
# A "drop channels and batch dimension" helper
275275
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
276-
276+
277277
for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct)
278278
@testset "$(conv)" begin
279279
# First, your basic convolution with no parameters
@@ -392,7 +392,7 @@ conv_answer_dict = Dict(
392392
D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))
393393

394394
# Skip tests that are impossible due to mismatched sizes
395-
try
395+
try
396396
DenseConvDims(x, w;
397397
stride=S_size, padding=P_size, dilation=D_size,
398398
)
@@ -473,7 +473,7 @@ end
473473

474474
# A "drop channels and batch dimension" helper
475475
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
476-
476+
477477
for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct)
478478
@testset "$(conv)" begin
479479
# First, your basic convolution with no parameters
@@ -592,7 +592,7 @@ end
592592
D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))
593593

594594
# Skip tests that are impossible due to mismatched sizes
595-
try
595+
try
596596
DepthwiseConvDims(x, w;
597597
stride=S_size, padding=P_size, dilation=D_size,
598598
)
@@ -639,3 +639,21 @@ end
639639
println()
640640
end
641641
end
642+
643+
@testset "conv_wrapper" begin
644+
x = rand(10, 10, 3, 10)
645+
w = rand(2, 2, 3, 16)
646+
w1 = rand(3, 4, 3, 16)
647+
@test size(conv(x, w)) == (9, 9, 16, 10)
648+
@test size(conv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 16, 10)
649+
@test size(conv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 16, 10)
650+
end
651+
652+
@testset "depthwiseconv_wrapper" begin
653+
x = rand(10, 10, 3, 10)
654+
w = rand(2, 2, 3, 3)
655+
w1 = rand(3, 4, 3, 3)
656+
@test size(depthwiseconv(x, w)) == (9, 9, 9, 10)
657+
@test size(depthwiseconv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 9, 10)
658+
@test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 9, 10)
659+
end

0 commit comments

Comments
 (0)