Skip to content

Commit b07aaa2

Browse files
authored
feat: Add CGConv support for HeteroGraphConv (#363)
* add gnnheterograph support for cgconv * remove changes for agnnconv for now * add test
1 parent 1dafb8d commit b07aaa2

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -944,14 +944,16 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
944944
return CGConv(ch, dense_f, dense_s, residual)
945945
end
946946

947-
function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
947+
function (l::CGConv)(g::AbstractGNNGraph, x,
948948
e::Union{Nothing, AbstractMatrix} = nothing)
949949
check_num_nodes(g, x)
950+
xj, xi = expand_srcdst(g, x)
951+
950952
if e !== nothing
951953
check_num_edges(g, e)
952954
end
953955

954-
m = propagate(message, g, +, l, xi = x, xj = x, e = e)
956+
m = propagate(message, g, +, l, xi = xi, xj = xj, e = e)
955957

956958
if l.residual
957959
if size(x, 1) == size(m, 1)
@@ -964,6 +966,7 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
964966
return m
965967
end
966968

969+
967970
function message(l::CGConv, xi, xj, e)
968971
if e !== nothing
969972
z = vcat(xi, xj, e)

src/msgpass.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ end
8787
# https://github.com/JuliaLang/julia/issues/15276
8888
## and zygote issues
8989
# https://github.com/FluxML/Zygote.jl/issues/1317
90-
function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
90+
function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
9191
e = nothing)
9292
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
9393
end
94-
function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
94+
function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
9595
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
9696
end
9797

test/layers/heteroconv.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,13 @@
9191
output = y2.A[:, [2]]
9292
@test expected output
9393
end
94+
95+
@testset "CGConv" begin
96+
g = rand_bipartite_heterograph((2,3), 6)
97+
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
98+
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
99+
(:B, :to, :A) => CGConv(4 => 2, relu));
100+
y = layers(g, x);
101+
@test size(y.A) == (2,2) && size(y.B) == (2,3)
102+
end
94103
end

0 commit comments

Comments
 (0)