Skip to content

Commit a0797e5

Browse files
authored
Fix heteroconv aggregation + tests (#333)
* Fix heteroconv aggregation + tests Fixes #332 * replace @Assert with @test
1 parent 9a4f4b7 commit a0797e5

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

src/layers/heteroconv.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,13 @@ function _reduceby_node_t(aggr, outs, ntypes)
7676
return foldl(aggr, outs[i] for i in idxs)
7777
end
7878
end
79-
vals = [_reduce(node_t) for node_t in ntypes]
80-
return NamedTuple{tuple(ntypes...)}(vals)
79+
# workaround to provide the aggregation once per unique node type,
80+
# gradient is not needed
81+
unique_ntypes = Flux.ChainRulesCore.ignore_derivatives() do
82+
unique(ntypes)
83+
end
84+
vals = [_reduce(node_t) for node_t in unique_ntypes]
85+
return NamedTuple{tuple(unique_ntypes...)}(vals)
8186
end
8287

8388
function Base.show(io::IO, hgc::HeteroGraphConv)

test/layers/heteroconv.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,62 @@
3333
(:B, :to, :A) => GraphConv(64 => 32, relu));
3434
@test length(layer.etypes) == 2
3535
end
36+
37+
@testset "Destination node aggregation" begin
38+
# deterministic setup to validate the aggregation
39+
d, n = 3, 5
40+
g = GNNHeteroGraph(((:A, :to, :B) => ([1, 1, 2, 3], [1, 2, 2, 3]),
41+
(:B, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3]),
42+
(:C, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3])); num_nodes = Dict(:A => n, :B => n, :C => n))
43+
model = HeteroGraphConv([
44+
(:A, :to, :B) => GraphConv(d => d, init = ones, bias = false),
45+
(:B, :to, :A) => GraphConv(d => d, init = ones, bias = false),
46+
(:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = +)
47+
x = (A = rand(Float32, d, n), B = rand(Float32, d, n), C = rand(Float32, d, n))
48+
y = model(g, x)
49+
weights = ones(Float32, d, d)
50+
51+
### Test default summation aggregation
52+
# B2 has 2 edges from A and itself (sense check)
53+
expected = sum(weights * x.A[:, [1, 2]]; dims = 2) .+ weights * x.B[:, [2]]
54+
output = y.B[:, [2]]
55+
@test expected output
56+
57+
# B5 has only itself
58+
@test weights * x.B[:, [5]] y.B[:, [5]]
59+
60+
# A1 has 1 edge from B, 1 from C and twice itself
61+
expected = sum(weights * x.B[:, [1]] + weights * x.C[:, [1]]; dims = 2) .+
62+
2 * weights * x.A[:, [1]]
63+
output = y.A[:, [1]]
64+
@test expected output
65+
66+
# A2 has 2 edges from B, 2 from C and twice itself
67+
expected = sum(weights * x.B[:, [1, 2]] + weights * x.C[:, [1, 2]]; dims = 2) .+
68+
2 * weights * x.A[:, [2]]
69+
output = y.A[:, [2]]
70+
@test expected output
71+
72+
# A5 has only itself but twice
73+
@test 2 * weights * x.A[:, [5]] y.A[:, [5]]
74+
75+
#### Test different aggregation function
76+
model2 = HeteroGraphConv([
77+
(:A, :to, :B) => GraphConv(d => d, init = ones, bias = false),
78+
(:B, :to, :A) => GraphConv(d => d, init = ones, bias = false),
79+
(:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = -)
80+
y2 = model2(g, x)
81+
# B no change
82+
@test y.B y2.B
83+
84+
# A1 has 1 edge from B, 1 from C, itself cancels out
85+
expected = sum(weights * x.B[:, [1]] - weights * x.C[:, [1]]; dims = 2)
86+
output = y2.A[:, [1]]
87+
@test expected output
88+
89+
# A2 has 2 edges from B, 2 from C, itself cancels out
90+
expected = sum(weights * x.B[:, [1, 2]] - weights * x.C[:, [1, 2]]; dims = 2)
91+
output = y2.A[:, [2]]
92+
@test expected output
93+
end
3694
end

0 commit comments

Comments
 (0)