Skip to content

Commit f9e12e7

Browse files
authored
NNlib: forward feature group counts to convolution (#109)
1 parent 9f450a0 commit f9e12e7

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ function NNlib.conv(
4646
stride = NNlib.stride(cdims)
4747
dilation = NNlib.dilation(cdims)
4848
flipkernel = NNlib.flipkernel(cdims)
49+
feature_group_count = NNlib.groupcount(cdims)
4950

5051
input_rank = ndims(x)
5152

@@ -102,7 +103,7 @@ function NNlib.conv(
102103
dimension_numbers,
103104
lhs_dilation=1,
104105
rhs_dilation=collect(dilation),
105-
feature_group_count=1,
106+
feature_group_count,
106107
batch_group_count=1,
107108
)
108109

test/nn.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,21 @@ out2 = model(noisy) # first row is prob. of true, second row p(false)
6868

6969
mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
7070

71-
@testset "conv" begin
72-
conv = Conv(randn(Float32, 10, 10, 3, 1), randn(Float32, 1))
71+
@testset "conv: groups $groups" for groups in (1, 2, 4)
72+
nn_conv = Conv(randn(Float32, 10, 10, 8 ÷ groups, groups), randn(Float32, groups); groups)
7373
conv_reactant = Conv(
74-
Reactant.ConcreteRArray(conv.weight), Reactant.ConcreteRArray(conv.bias)
74+
Reactant.ConcreteRArray(nn_conv.weight), Reactant.ConcreteRArray(nn_conv.bias);
75+
groups
7576
)
7677

77-
img = randn(Float32, 224, 224, 3, 2)
78+
img = randn(Float32, 224, 224, 8, 2)
7879
img_reactant = Reactant.ConcreteRArray(img)
7980

8081
comp_conv = Reactant.compile(conv_reactant, (img_reactant,))
8182

8283
res_reactant = Array{Float32,4}(comp_conv(img_reactant))
83-
res = conv(img)
8484

85-
@test res_reactant res
85+
@test res_reactant nn_conv(img)
8686
end
8787

8888
@testset "$f" for f in (NNlib.meanpool, NNlib.maxpool)

0 commit comments

Comments
 (0)