Skip to content

Commit 983669e

Browse files
authored
refinement: Self loops for HeteroGraph returns g instead of error if src != tgt (#373)
* return g instead of error if src != tgt * add test and docs * fix typo in test
1 parent 8a6802b commit 983669e

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/GNNGraphs/transform.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,17 @@ end
4141
"""
4242
add_self_loops(g::GNNHeteroGraph, edge_t::EType)
4343
44-
Return a graph with the same features as `g`
45-
but also adding self-loops of the specified type, edge_t
44+
If the source node type is the same as destination node type in `edge_t`,
45+
return a graph with the same features as `g` but also adding self-loops
46+
of the specified type, `edge_t`. Otherwise it returns `g` unchanged.
4647
47-
Nodes with already existing self-loops of type edge_t will obtain a second self-loop of type edge_t.
48+
Nodes with already existing self-loops of type edge_t will obtain
49+
a second self-loop of type edge_t.
4850
4951
If the graphs has edge weights for edges of type edge_t, the new edges will have weight 1.
5052
51-
If no edges of type edge_t exist, or all existing edges have no weight, then all new self loops will have no weight.
53+
If no edges of type edge_t exist, or all existing edges have no weight,
54+
then all new self loops will have no weight.
5255
"""
5356
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
5457
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
@@ -57,7 +60,7 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
5760

5861
src_t, _, tgt_t = edge_t
5962
(src_t === tgt_t) ||
60-
@error "cannot add a self-loop with different source and target types"
63+
return g
6164

6265
n = get(g.num_nodes, src_t, 0)
6366

test/GNNGraphs/gnnheterograph.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ end
177177
@test g3.num_nodes[:C] == 10
178178
end
179179

180+
@testset "add self loops" begin
181+
g1 = GNNHeteroGraph((:A, :to, :B) => ([1,2,3,4], [3,2,1,5]))
182+
g2 = add_self_loops(g1, (:A, :to, :B))
183+
@test g2.num_edges[(:A, :to, :B)] === g1.num_edges[(:A, :to, :B)]
184+
g1 = GNNHeteroGraph((:A, :to, :A) => ([1,2,3,4], [3,2,1,5]))
185+
g2 = add_self_loops(g1, (:A, :to, :A))
186+
@test g2.num_edges[(:A, :to, :A)] === g1.num_edges[(:A, :to, :A)] + g1.num_nodes[(:A)]
187+
end
188+
180189
## Cannot test this because DataStore is not an ordered collection
181190
## Uncomment when/if it will be based on OrderedDict
182191
# @testset "show" begin

0 commit comments

Comments
 (0)