@@ -1502,46 +1502,49 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
1502
1502
SGConv (W, b, k, add_self_loops, use_edge_weight)
1503
1503
end
1504
1504
1505
- function (l:: SGConv )(g:: GNNGraph , x:: AbstractMatrix{T} ,
1505
+ function (l:: SGConv )(g:: AbstractGNNGraph , x,
1506
1506
edge_weight:: EW = nothing ) where
1507
- {T, EW <: Union{Nothing, AbstractVector} }
1507
+ { EW <: Union{Nothing, AbstractVector} }
1508
1508
@assert ! (g isa GNNGraph{<: ADJMAT_T } && edge_weight != = nothing ) " Providing external edge_weight is not yet supported for adjacency matrix graphs"
1509
1509
1510
+ xj, xi = expand_srcdst (g, x)
1511
+ edge_t = g isa GNNHeteroGraph ? g. etypes[1 ] : nothing
1512
+ T = eltype (xi)
1513
+
1510
1514
if edge_weight != = nothing
1511
1515
@assert length (edge_weight)== g. num_edges " Wrong number of edge weights (expected $(g. num_edges) but given $(length (edge_weight)) )"
1512
1516
end
1513
1517
1514
1518
if l. add_self_loops
1515
- g = add_self_loops (g)
1519
+ g = g isa GNNHeteroGraph ? add_self_loops (g, edge_t) : add_self_loops (g)
1516
1520
if edge_weight != = nothing
1517
1521
edge_weight = [edge_weight; fill! (similar (edge_weight, g. num_nodes), 1 )]
1518
1522
@assert length (edge_weight) == g. num_edges
1519
1523
end
1520
1524
end
1521
1525
Dout, Din = size (l. weight)
1522
- if Dout < Din
1523
- x = l. weight * x
1524
- end
1525
- if edge_weight != = nothing
1526
- d = degree (g, T; dir = :in , edge_weight)
1526
+ if g isa GNNHeteroGraph
1527
+ d = degree (g, edge_t, T; dir = :in )
1527
1528
else
1528
- d = degree (g, T; dir = :in , edge_weight= l. use_edge_weight)
1529
+ if edge_weight != = nothing
1530
+ d = degree (g, T; dir = :in , edge_weight)
1531
+ else
1532
+ d = degree (g, T; dir = :in , edge_weight= l. use_edge_weight)
1533
+ end
1529
1534
end
1530
1535
c = 1 ./ sqrt .(d)
1531
1536
for iter in 1 : (l. k)
1532
1537
x = x .* c'
1533
1538
if edge_weight != = nothing
1534
- x = propagate (e_mul_xj, g, + , xj = x , e = edge_weight)
1539
+ x = propagate (e_mul_xj, g, + , xj = xj , e = edge_weight)
1535
1540
elseif l. use_edge_weight
1536
- x = propagate (w_mul_xj, g, + , xj = x )
1541
+ x = propagate (w_mul_xj, g, + , xj = xj )
1537
1542
else
1538
- x = propagate (copy_xj, g, + , xj = x )
1543
+ x = propagate (copy_xj, g, + , xj = xj )
1539
1544
end
1540
1545
x = x .* c'
1541
1546
end
1542
- if Dout >= Din
1543
- x = l. weight * x
1544
- end
1547
+ x = l. weight * x
1545
1548
return (x .+ l. bias)
1546
1549
end
1547
1550
0 commit comments