Skip to content

Commit 938f060

Browse files
authored
Add GATv2Conv support to HeteroGraphConv (#407)
* hetero support * Update conv.jl
1 parent 74e2e07 commit 938f060

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,21 +579,23 @@ end
579579

580580
(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
581581

582-
function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix,
582+
function (l::GATv2Conv)(g::AbstractGNNGraph, x,
583583
e::Union{Nothing, AbstractMatrix} = nothing)
584584
check_num_nodes(g, x)
585585
@assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
586586
@assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
587587

588+
xj, xi = expand_srcdst(g, x)
589+
588590
if l.add_self_loops
589591
@assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
590592
g = add_self_loops(g)
591593
end
592594
_, out = l.channel
593595
heads = l.heads
594596

595-
Wxi = reshape(l.dense_i(x), out, heads, :) # out × heads × nnodes
596-
Wxj = reshape(l.dense_j(x), out, heads, :) # out × heads × nnodes
597+
Wxi = reshape(l.dense_i(xi), out, heads, :) # out × heads × nnodes
598+
Wxj = reshape(l.dense_j(xj), out, heads, :) # out × heads × nnodes
597599

598600
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
599601
α = softmax_edge_neighbors(g, m.logα)

test/layers/heteroconv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,12 @@
148148
y = layers(hg, x);
149149
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
150150
end
151+
152+
@testset "GATv2Conv" begin
153+
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
154+
layers = HeteroGraphConv((:A, :to, :B) => GATv2Conv(4 => 2),
155+
(:B, :to, :A) => GATv2Conv(4 => 2));
156+
y = layers(hg, x);
157+
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
158+
end
151159
end

0 commit comments

Comments
 (0)