Skip to content

Commit 12d0a04

Browse files
committed
test: move conv configurations
1 parent 29aff8e commit 12d0a04

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

test/nn/nnlib.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,17 @@ end
5454

5555
@testset "Convolution" begin
5656
@testset for groups in (1, 2, 4)
57-
weight = randn(Float32, 10, 10, 8 ÷ groups, groups)
57+
weight = randn(Float32, 4, 4, 8 ÷ groups, groups)
5858
x = randn(Float32, 16, 16, 8, 2)
5959

6060
weight_reactant = Reactant.ConcreteRArray(weight)
6161
x_reactant = Reactant.ConcreteRArray(x)
6262

63-
@testset for stride in ((1, 1), (2, 2), (3, 3)), padding in ((0, 0), (1, 1), (2, 2))
64-
conv_dims = DenseConvDims(x, weight; stride, padding, dilation=1, groups)
63+
@testset for stride in ((1, 1), (2, 2), (3, 3)),
64+
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)),
65+
dilation in ((1, 1), (2, 2), (1, 2), (2, 1))
66+
67+
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
6568

6669
conv_compiled = Reactant.compile(
6770
NNlib.conv, (x_reactant, weight_reactant, conv_dims)

0 commit comments

Comments
 (0)