@@ -1522,15 +1522,13 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
1522
1522
SGConv (W, b, k, add_self_loops, use_edge_weight)
1523
1523
end
1524
1524
1525
- function (l:: SGConv )(g:: AbstractGNNGraph , x,
1525
+ # this layer is not stable enough to be supported by GNNHeteroGraph type
1526
+ # due to it's looping mechanism
1527
+ function (l:: SGConv )(g:: GNNGraph , x:: AbstractMatrix{T} ,
1526
1528
edge_weight:: EW = nothing ) where
1527
- { EW <: Union{Nothing, AbstractVector} }
1529
+ {T, EW <: Union{Nothing, AbstractVector} }
1528
1530
@assert ! (g isa GNNGraph{<: ADJMAT_T } && edge_weight != = nothing ) " Providing external edge_weight is not yet supported for adjacency matrix graphs"
1529
1531
1530
- xj, xi = expand_srcdst (g, x)
1531
- edge_t = g isa GNNHeteroGraph ? g. etypes[1 ] : nothing
1532
- T = eltype (xi)
1533
-
1534
1532
if edge_weight != = nothing
1535
1533
@assert length (edge_weight)== g. num_edges " Wrong number of edge weights (expected $(g. num_edges) but given $(length (edge_weight)) )"
1536
1534
end
@@ -1543,28 +1541,29 @@ function (l::SGConv)(g::AbstractGNNGraph, x,
1543
1541
end
1544
1542
end
1545
1543
Dout, Din = size (l. weight)
1546
- if g isa GNNHeteroGraph
1547
- d = degree (g, edge_t, T; dir = :in )
1544
+ if Dout < Din
1545
+ x = l. weight * x
1546
+ end
1547
+ if edge_weight != = nothing
1548
+ d = degree (g, T; dir = :in , edge_weight)
1548
1549
else
1549
- if edge_weight != = nothing
1550
- d = degree (g, T; dir = :in , edge_weight)
1551
- else
1552
- d = degree (g, T; dir = :in , edge_weight= l. use_edge_weight)
1553
- end
1550
+ d = degree (g, T; dir = :in , edge_weight= l. use_edge_weight)
1554
1551
end
1555
1552
c = 1 ./ sqrt .(d)
1556
1553
for iter in 1 : (l. k)
1557
1554
x = x .* c'
1558
1555
if edge_weight != = nothing
1559
- x = propagate (e_mul_xj, g, + , xj = xj , e = edge_weight)
1556
+ x = propagate (e_mul_xj, g, + , xj = x , e = edge_weight)
1560
1557
elseif l. use_edge_weight
1561
- x = propagate (w_mul_xj, g, + , xj = xj )
1558
+ x = propagate (w_mul_xj, g, + , xj = x )
1562
1559
else
1563
- x = propagate (copy_xj, g, + , xj = xj )
1560
+ x = propagate (copy_xj, g, + , xj = x )
1564
1561
end
1565
1562
x = x .* c'
1566
1563
end
1567
- x = l. weight * x
1564
+ if Dout >= Din
1565
+ x = l. weight * x
1566
+ end
1568
1567
return (x .+ l. bias)
1569
1568
end
1570
1569
0 commit comments