Skip to content

Commit 0c641a2

Browse files
authored
feat: Add GCNConv support for HeteroGraphConv (#367)
* WIP: working state but with degree placeholder * clean up comments * add test * include edge type in degree calc for gnnheterograph * add self loops for gnnheterograph * add TODO comment * add TODO comment * fix failing test * update the new add_self_loops_behavior * change empty string to nothing for memory optimization * add GCNConv support for HeteroGraphConv * fix tests * GCN tests passing * add small optimization to reduce ifs * avoid repeated code * add PR review suggestion * run all tests
1 parent 95e8392 commit 0c641a2

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

src/layers/conv.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,22 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity;
8888
GCNConv(W, b, σ, add_self_loops, use_edge_weight)
8989
end
9090

91-
check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
91+
check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
9292
throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))
9393

94-
function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector)
94+
function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector)
9595
if length(edge_weight) !== g.num_edges
9696
throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
9797
end
9898
end
9999

100-
check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing
100+
check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing
101101

102-
103-
function (l::GCNConv)(g::GNNGraph,
104-
x::AbstractMatrix{T},
102+
function (l::GCNConv)(g::AbstractGNNGraph,
103+
x,
105104
edge_weight::EW = nothing,
106105
norm_fn::Function = d -> 1 ./ sqrt.(d)
107-
) where {T, EW <: Union{Nothing, AbstractVector}}
106+
) where {EW <: Union{Nothing, AbstractVector}}
108107

109108
check_gcnconv_input(g, edge_weight)
110109

@@ -118,26 +117,35 @@ function (l::GCNConv)(g::GNNGraph,
118117
end
119118
end
120119
Dout, Din = size(l.weight)
121-
if Dout < Din
120+
if Dout < Din && !(g isa GNNHeteroGraph)
122121
# multiply before convolution if it is more convenient, otherwise multiply after
122+
# (this works only for homogenous graph)
123123
x = l.weight * x
124124
end
125-
if edge_weight !== nothing
126-
d = degree(g, T; dir = :in, edge_weight)
125+
126+
xj, xi = expand_srcdst(g, x) # expand only after potential multiplication
127+
T = eltype(xi)
128+
129+
if g isa GNNHeteroGraph
130+
d = degree(g, g.etypes[1], T; dir = :in)
127131
else
128-
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
132+
if edge_weight !== nothing
133+
d = degree(g, T; dir = :in, edge_weight)
134+
else
135+
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
136+
end
129137
end
130138
c = norm_fn(d)
131-
x = x .* c'
139+
!(g isa GNNHeteroGraph) ? xj = xj .* c' : Nothing
132140
if edge_weight !== nothing
133-
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
141+
x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
134142
elseif l.use_edge_weight
135-
x = propagate(w_mul_xj, g, +, xj = x)
143+
x = propagate(w_mul_xj, g, +, xj = xj)
136144
else
137-
x = propagate(copy_xj, g, +, xj = x)
145+
x = propagate(copy_xj, g, +, xj = xj)
138146
end
139147
x = x .* c'
140-
if Dout >= Din
148+
if Dout >= Din || g isa GNNHeteroGraph
141149
x = l.weight * x
142150
end
143151
return l.σ.(x .+ l.bias)

test/layers/heteroconv.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,13 @@
156156
y = layers(hg, x);
157157
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
158158
end
159+
160+
@testset "GCNConv" begin
161+
g = rand_bipartite_heterograph((2,3), 6)
162+
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
163+
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu),
164+
(:B, :to, :A) => GCNConv(4 => 2, relu));
165+
y = layers(g, x);
166+
@test size(y.A) == (2,2) && size(y.B) == (2,3)
167+
end
159168
end

0 commit comments

Comments
 (0)