Skip to content

Commit 5a6bb6a

Browse files
authored
resgated hetero (#391)
1 parent 884b473 commit 5a6bb6a

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/layers/conv.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -863,18 +863,19 @@ function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
863863
return ResGatedGraphConv(A, B, U, V, b, σ)
864864
end
865865

866-
function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
866+
function (l::ResGatedGraphConv)(g::AbstractGNNGraph, x)
867867
check_num_nodes(g, x)
868+
xj, xi = expand_srcdst(g, x)
868869

869870
message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx
870871

871-
Ax = l.A * x
872-
Bx = l.B * x
873-
Vx = l.V * x
872+
Ax = l.A * xi
873+
Bx = l.B * xj
874+
Vx = l.V * xj
874875

875876
m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx))
876877

877-
return l.σ.(l.U * x .+ m .+ l.bias)
878+
return l.σ.(l.U * xi .+ m .+ l.bias)
878879
end
879880

880881
function Base.show(io::IO, l::ResGatedGraphConv)

test/layers/heteroconv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,12 @@
116116
y = layers(hg, x);
117117
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
118118
end
119+
120+
@testset "ResGatedGraphConv" begin
121+
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
122+
layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2),
123+
(:B, :to, :A) => ResGatedGraphConv(4 => 2));
124+
y = layers(hg, x);
125+
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
126+
end
119127
end

0 commit comments

Comments
 (0)