Skip to content

Commit 4c53e51

Browse files
Add GATConv support for HeteroGraphConv (#400)
* gat hetero support * Update src/layers/conv.jl Co-authored-by: Carlo Lucibello <[email protected]> * changes made * fix * Update src/layers/conv.jl Co-authored-by: Carlo Lucibello <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent b26f084 commit 4c53e51

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

src/layers/conv.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,25 +409,36 @@ end
409409

410410
(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
411411

412-
function (l::GATConv)(g::GNNGraph, x::AbstractMatrix,
412+
function (l::GATConv)(g::AbstractGNNGraph, x,
413413
e::Union{Nothing, AbstractMatrix} = nothing)
414414
check_num_nodes(g, x)
415415
@assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
416416
@assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
417417

418+
xj, xi = expand_srcdst(g, x)
419+
418420
if l.add_self_loops
419421
@assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
420-
g = add_self_loops(g)
422+
if g isa GNNHeteroGraph
423+
g = add_self_loops(g, g.etypes[1])
424+
else
425+
g = add_self_loops(g)
426+
end
421427
end
422428

423429
_, chout = l.channel
424430
heads = l.heads
425431

426-
Wx = l.dense_x(x)
427-
Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes
432+
Wxi = Wxj = l.dense_x(xj)
433+
Wxi = Wxj = reshape(Wxj, chout, heads, :)
434+
435+
if xi !== xj
436+
Wxi = l.dense_x(xi)
437+
Wxi = reshape(Wxi, chout, heads, :)
438+
end
428439

429440
# a hand-written message passing
430-
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
441+
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
431442
α = softmax_edge_neighbors(g, m.logα)
432443
β = α .* m.Wxj
433444
x = aggregate_neighbors(g, +, β)

src/utils.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ Softmax over each node's neighborhood of the edge features `e`.
8181
{\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}.
8282
```
8383
"""
84-
function softmax_edge_neighbors(g::GNNGraph, e)
85-
@assert size(e)[end] == g.num_edges
84+
function softmax_edge_neighbors(g::AbstractGNNGraph, e)
85+
if g isa GNNHeteroGraph
86+
for (key, value) in g.num_edges
87+
@assert size(e)[end] == value
88+
end
89+
else
90+
@assert size(e)[end] == g.num_edges
91+
end
8692
s, t = edge_index(g)
8793
max_ = gather(scatter(max, e, t), t)
8894
num = exp.(e .- max_)

test/layers/heteroconv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@
125125
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
126126
end
127127

128+
@testset "GATConv" begin
129+
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
130+
layers = HeteroGraphConv((:A, :to, :B) => GATConv(4 => 2),
131+
(:B, :to, :A) => GATConv(4 => 2));
132+
y = layers(hg, x);
133+
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
134+
end
135+
128136
@testset "GINConv" begin
129137
x = (A = rand(4, 2), B = rand(4, 3))
130138
layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4),

0 commit comments

Comments
 (0)