@@ -68,21 +68,21 @@ out2 = model(noisy) # first row is prob. of true, second row p(false)
68
68
69
69
mean ((out2[1 , :] .> 0.5 ) .== truth) # accuracy 94% so far!
70
70
71
- @testset " conv" begin
72
- conv = Conv (randn (Float32, 10 , 10 , 3 , 1 ), randn (Float32, 1 ) )
71
+ @testset " conv: groups $groups " for groups in ( 1 , 2 , 4 )
72
+ nn_conv = Conv (randn (Float32, 10 , 10 , 8 ÷ groups, groups ), randn (Float32, groups); groups )
73
73
conv_reactant = Conv (
74
- Reactant. ConcreteRArray (conv. weight), Reactant. ConcreteRArray (conv. bias)
74
+ Reactant. ConcreteRArray (nn_conv. weight), Reactant. ConcreteRArray (nn_conv. bias);
75
+ groups
75
76
)
76
77
77
- img = randn (Float32, 224 , 224 , 3 , 2 )
78
+ img = randn (Float32, 224 , 224 , 8 , 2 )
78
79
img_reactant = Reactant. ConcreteRArray (img)
79
80
80
81
comp_conv = Reactant. compile (conv_reactant, (img_reactant,))
81
82
82
83
res_reactant = Array {Float32,4} (comp_conv (img_reactant))
83
- res = conv (img)
84
84
85
- @test res_reactant ≈ res
85
+ @test res_reactant ≈ nn_conv (img)
86
86
end
87
87
88
88
@testset " $f " for f in (NNlib. meanpool, NNlib. maxpool)
0 commit comments