Skip to content

Commit b1e1cb1

Browse files
authored
New add_self_loops(g) for hetero graphs (#402)
* new add self loops for hetero graphs * added support for sgconv * concise * newline * Update transform.jl * add self loop * change docs * docs * docs
1 parent 4c53e51 commit b1e1cb1

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/GNNGraphs/transform.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ end
4040

4141
"""
4242
add_self_loops(g::GNNHeteroGraph, edge_t::EType)
43+
add_self_loops(g::GNNHeteroGraph)
4344
4445
If the source node type is the same as the destination node type in `edge_t`,
4546
return a graph with the same features as `g` but also add self-loops
@@ -51,7 +52,10 @@ a second set of self-loops of the same type.
5152
If the graph has edge weights for edges of type `edge_t`, the new edges will have weight 1.
5253
5354
If no edges of type `edge_t` exist, or all existing edges have no weight,
54-
then all new self-loops will have no weight.
55+
then all new self loops will have no weight.
56+
57+
If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same.
58+
This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type.
5559
"""
5660
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
5761
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
@@ -99,6 +103,13 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
99103
ntypes, etypes)
100104
end
101105

106+
function add_self_loops(g::GNNHeteroGraph)
107+
for edge_t in keys(g.graph)
108+
g = add_self_loops(g, edge_t)
109+
end
110+
return g
111+
end
112+
102113
"""
103114
remove_self_loops(g::GNNGraph)
104115

src/layers/conv.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,7 @@ function (l::GATConv)(g::AbstractGNNGraph, x,
419419

420420
if l.add_self_loops
421421
@assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
422-
if g isa GNNHeteroGraph
423-
g = add_self_loops(g, g.etypes[1])
424-
else
425-
g = add_self_loops(g)
426-
end
422+
g = add_self_loops(g)
427423
end
428424

429425
_, chout = l.channel
@@ -1530,7 +1526,7 @@ function (l::SGConv)(g::AbstractGNNGraph, x,
15301526
end
15311527

15321528
if l.add_self_loops
1533-
g = g isa GNNHeteroGraph ? add_self_loops(g, edge_t) : add_self_loops(g)
1529+
g = add_self_loops(g)
15341530
if edge_weight !== nothing
15351531
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
15361532
@assert length(edge_weight) == g.num_edges

0 commit comments

Comments
 (0)