Skip to content

Commit 64b3cbb

Browse files
rbSparkyrishabh-aaryanaskorupka
authored
Add SGConv support for HeteroGraphConv (#383)
* sgconv hetero * WIP * tests work * Code references PR 367+399 by @askorupka --------- Co-authored-by: rbSparky <[email protected]> Co-authored-by: askorupka <[email protected]>
1 parent b55f0fa commit 64b3cbb

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

src/layers/conv.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,46 +1502,49 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
15021502
SGConv(W, b, k, add_self_loops, use_edge_weight)
15031503
end
15041504

1505-
function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
1505+
function (l::SGConv)(g::AbstractGNNGraph, x,
15061506
edge_weight::EW = nothing) where
1507-
{T, EW <: Union{Nothing, AbstractVector}}
1507+
{EW <: Union{Nothing, AbstractVector}}
15081508
@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
15091509

1510+
xj, xi = expand_srcdst(g, x)
1511+
edge_t = g isa GNNHeteroGraph ? g.etypes[1] : nothing
1512+
T = eltype(xi)
1513+
15101514
if edge_weight !== nothing
15111515
@assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
15121516
end
15131517

15141518
if l.add_self_loops
1515-
g = add_self_loops(g)
1519+
g = g isa GNNHeteroGraph ? add_self_loops(g, edge_t) : add_self_loops(g)
15161520
if edge_weight !== nothing
15171521
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
15181522
@assert length(edge_weight) == g.num_edges
15191523
end
15201524
end
15211525
Dout, Din = size(l.weight)
1522-
if Dout < Din
1523-
x = l.weight * x
1524-
end
1525-
if edge_weight !== nothing
1526-
d = degree(g, T; dir = :in, edge_weight)
1526+
if g isa GNNHeteroGraph
1527+
d = degree(g, edge_t, T; dir = :in)
15271528
else
1528-
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
1529+
if edge_weight !== nothing
1530+
d = degree(g, T; dir = :in, edge_weight)
1531+
else
1532+
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
1533+
end
15291534
end
15301535
c = 1 ./ sqrt.(d)
15311536
for iter in 1:(l.k)
15321537
x = x .* c'
15331538
if edge_weight !== nothing
1534-
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
1539+
x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
15351540
elseif l.use_edge_weight
1536-
x = propagate(w_mul_xj, g, +, xj = x)
1541+
x = propagate(w_mul_xj, g, +, xj = xj)
15371542
else
1538-
x = propagate(copy_xj, g, +, xj = x)
1543+
x = propagate(copy_xj, g, +, xj = xj)
15391544
end
15401545
x = x .* c'
15411546
end
1542-
if Dout >= Din
1543-
x = l.weight * x
1544-
end
1547+
x = l.weight * x
15451548
return (x .+ l.bias)
15461549
end
15471550

test/layers/heteroconv.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,15 @@
108108
y = layers(hg, x);
109109
@test size(y.A) == (2,2) && size(y.B) == (2,3)
110110
end
111-
111+
112+
@testset "SGConv" begin
113+
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
114+
layers = HeteroGraphConv((:A, :to, :B) => SGConv(4 => 2),
115+
(:B, :to, :A) => SGConv(4 => 2));
116+
y = layers(hg, x);
117+
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
118+
end
119+
112120
@testset "SAGEConv" begin
113121
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
114122
layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +),

0 commit comments

Comments
 (0)