|
54 | 54 |
|
55 | 55 | @testset "Convolution" begin
|
56 | 56 | @testset for groups in (1, 2, 4)
|
57 |
| - weight = randn(Float32, 10, 10, 8 ÷ groups, groups) |
| 57 | + weight = randn(Float32, 4, 4, 8 ÷ groups, groups) |
58 | 58 | x = randn(Float32, 16, 16, 8, 2)
|
59 | 59 |
|
60 | 60 | weight_reactant = Reactant.ConcreteRArray(weight)
|
61 | 61 | x_reactant = Reactant.ConcreteRArray(x)
|
62 | 62 |
|
63 |
| - @testset for stride in ((1, 1), (2, 2), (3, 3)), padding in ((0, 0), (1, 1), (2, 2)) |
64 |
| - conv_dims = DenseConvDims(x, weight; stride, padding, dilation=1, groups) |
| 63 | + @testset for stride in ((1, 1), (2, 2), (3, 3)), |
| 64 | + padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)), |
| 65 | + dilation in ((1, 1), (2, 2), (1, 2), (2, 1)) |
| 66 | + |
| 67 | + conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups) |
65 | 68 |
|
66 | 69 | conv_compiled = Reactant.compile(
|
67 | 70 | NNlib.conv, (x_reactant, weight_reactant, conv_dims)
|
|
0 commit comments