Skip to content

Commit 005c5e7

Browse files
authored
fix: remove SGConv GNNHeteroGraph support (#416)
1 parent f05af33 commit 005c5e7

File tree

2 files changed

+16
-25
lines changed

2 files changed

+16
-25
lines changed

src/layers/conv.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,15 +1522,13 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
15221522
SGConv(W, b, k, add_self_loops, use_edge_weight)
15231523
end
15241524

1525-
function (l::SGConv)(g::AbstractGNNGraph, x,
1525+
# this layer is not stable enough to be supported by GNNHeteroGraph type
1526+
# due to it's looping mechanism
1527+
function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
15261528
edge_weight::EW = nothing) where
1527-
{EW <: Union{Nothing, AbstractVector}}
1529+
{T, EW <: Union{Nothing, AbstractVector}}
15281530
@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
15291531

1530-
xj, xi = expand_srcdst(g, x)
1531-
edge_t = g isa GNNHeteroGraph ? g.etypes[1] : nothing
1532-
T = eltype(xi)
1533-
15341532
if edge_weight !== nothing
15351533
@assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
15361534
end
@@ -1543,28 +1541,29 @@ function (l::SGConv)(g::AbstractGNNGraph, x,
15431541
end
15441542
end
15451543
Dout, Din = size(l.weight)
1546-
if g isa GNNHeteroGraph
1547-
d = degree(g, edge_t, T; dir = :in)
1544+
if Dout < Din
1545+
x = l.weight * x
1546+
end
1547+
if edge_weight !== nothing
1548+
d = degree(g, T; dir = :in, edge_weight)
15481549
else
1549-
if edge_weight !== nothing
1550-
d = degree(g, T; dir = :in, edge_weight)
1551-
else
1552-
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
1553-
end
1550+
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
15541551
end
15551552
c = 1 ./ sqrt.(d)
15561553
for iter in 1:(l.k)
15571554
x = x .* c'
15581555
if edge_weight !== nothing
1559-
x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
1556+
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
15601557
elseif l.use_edge_weight
1561-
x = propagate(w_mul_xj, g, +, xj = xj)
1558+
x = propagate(w_mul_xj, g, +, xj = x)
15621559
else
1563-
x = propagate(copy_xj, g, +, xj = xj)
1560+
x = propagate(copy_xj, g, +, xj = x)
15641561
end
15651562
x = x .* c'
15661563
end
1567-
x = l.weight * x
1564+
if Dout >= Din
1565+
x = l.weight * x
1566+
end
15681567
return (x .+ l.bias)
15691568
end
15701569

test/layers/heteroconv.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,6 @@
109109
@test size(y.A) == (2,2) && size(y.B) == (2,3)
110110
end
111111

112-
@testset "SGConv" begin
113-
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
114-
layers = HeteroGraphConv((:A, :to, :B) => SGConv(4 => 2),
115-
(:B, :to, :A) => SGConv(4 => 2));
116-
y = layers(hg, x);
117-
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
118-
end
119-
120112
@testset "SAGEConv" begin
121113
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
122114
layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +),

0 commit comments

Comments
 (0)