Skip to content

Commit 5148d96

Browse files
Add SAGEConv support to HeteroGraphConv (#384)
* SAGEConv Hetero Layer * without tests * test doesnt work yet * tests * tests should work * test update * temporary testing fast * final tests --------- Co-authored-by: rbSparky <[email protected]>
1 parent 2681e21 commit 5148d96

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/layers/conv.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -801,10 +801,11 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean,
801801
SAGEConv(W, b, σ, aggr)
802802
end
803803

804-
function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
804+
function (l::SAGEConv)(g::AbstractGNNGraph, x)
805805
check_num_nodes(g, x)
806-
m = propagate(copy_xj, g, l.aggr, xj = x)
807-
x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
806+
xj, xi = expand_srcdst(g, x)
807+
m = propagate(copy_xj, g, l.aggr, xj = xj)
808+
x = l.σ.(l.weight * vcat(xi, m) .+ l.bias)
808809
return x
809810
end
810811

test/layers/heteroconv.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,11 @@
109109
@test size(y.A) == (2,2) && size(y.B) == (2,3)
110110
end
111111

112+
@testset "SAGEConv" begin
113+
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
114+
layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +),
115+
(:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +));
116+
y = layers(hg, x);
117+
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
118+
end
112119
end

0 commit comments

Comments
 (0)