Skip to content

Commit acc9ead

Browse files
authored
fix: GCNConv support for GNNHeteroGraph (#417)
* fix: GCNConv support for GNNHeteroGraph * add PR review suggestion
1 parent 005c5e7 commit acc9ead

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/layers/conv.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,24 +127,28 @@ function (l::GCNConv)(g::AbstractGNNGraph,
127127
T = eltype(xi)
128128

129129
if g isa GNNHeteroGraph
130-
d = degree(g, g.etypes[1], T; dir = :in)
130+
din = degree(g, g.etypes[1], T; dir = :in)
131+
dout = degree(g, g.etypes[1], T; dir = :out)
132+
133+
cout = norm_fn(dout)
134+
cin = norm_fn(din)
131135
else
132136
if edge_weight !== nothing
133137
d = degree(g, T; dir = :in, edge_weight)
134138
else
135139
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
136140
end
141+
cin = cout = norm_fn(d)
137142
end
138-
c = norm_fn(d)
139-
!(g isa GNNHeteroGraph) ? xj = xj .* c' : Nothing
143+
xj = xj .* cout'
140144
if edge_weight !== nothing
141145
x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
142146
elseif l.use_edge_weight
143147
x = propagate(w_mul_xj, g, +, xj = xj)
144148
else
145149
x = propagate(copy_xj, g, +, xj = xj)
146150
end
147-
x = x .* c'
151+
x = x .* cin'
148152
if Dout >= Din || g isa GNNHeteroGraph
149153
x = l.weight * x
150154
end

0 commit comments

Comments
 (0)