Skip to content

Commit 7ddadab

Browse files
feat: Add EdgeConv support for HeteroGraphConv (#364)
* add edgeconv heterograph support * add test & resolve conflicts * cleanup * Update test/layers/heteroconv.jl Co-authored-by: Carlo Lucibello <[email protected]> * cleanup after rebase --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent b07aaa2 commit 7ddadab

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,10 +637,13 @@ end
637637

638638
EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)
639639

640-
function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix)
640+
function (l::EdgeConv)(g::AbstractGNNGraph, x)
641641
check_num_nodes(g, x)
642+
xj, xi = expand_srcdst(g, x)
643+
642644
message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi))
643-
x = propagate(message, g, l.aggr, l, xi = x, xj = x)
645+
646+
x = propagate(message, g, l.aggr, l, xi = xi, xj = xj, e = nothing)
644647
return x
645648
end
646649

test/layers/heteroconv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,14 @@
100100
y = layers(g, x);
101101
@test size(y.A) == (2,2) && size(y.B) == (2,3)
102102
end
103+
104+
@testset "EdgeConv" begin
105+
g = rand_bipartite_heterograph((2,3), 6)
106+
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
107+
layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +),
108+
(:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +));
109+
y = layers(g, x);
110+
@test size(y.A) == (2,2) && size(y.B) == (2,3)
111+
end
112+
103113
end

0 commit comments

Comments
 (0)