Skip to content

Commit fa584e5

Browse files
authored
Add heterogeneous add_self_loop support (#345)
* Add heterogeneous add_self_loop support * Fixed array types for add_self_loop on GNNHeteroGraph * Remove debugging info
1 parent 25b0323 commit fa584e5

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

src/GNNGraphs/transform.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,64 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T})
3838
g.ndata, g.edata, g.gdata)
3939
end
4040

41+
"""
42+
add_self_loops(g::GNNHeteroGraph, edge_t::EType)
43+
44+
Return a graph with the same features as `g`
45+
but also adding self-loops of the specified type, edge_t
46+
47+
Nodes with already existing self-loops of type edge_t will obtain a second self-loop of type edge_t.
48+
49+
If the graphs has edge weights for edges of type edge_t, the new edges will have weight 1.
50+
51+
If no edges of type edge_t exist, or all existing edges have no weight, then all new self loops will have no weight.
52+
"""
53+
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
54+
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
55+
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
56+
end
57+
58+
src_t, _, tgt_t = edge_t
59+
(src_t === tgt_t) ||
60+
@error "cannot add a self-loop with different source and target types"
61+
62+
n = get(g.num_nodes, src_t, 0)
63+
64+
if haskey(g.graph, edge_t)
65+
x = g.graph[edge_t]
66+
s, t = x[1:2]
67+
nodes = convert(typeof(s), [1:n;])
68+
s = [s; nodes]
69+
t = [t; nodes]
70+
else
71+
nodes = convert(T, [1:n;])
72+
s = nodes
73+
t = nodes
74+
end
75+
76+
graph = g.graph |> copy
77+
ew = get(g.graph, edge_t, (nothing, nothing, nothing))[3]
78+
79+
if ew !== nothing
80+
ew = [ew; fill!(similar(ew, n), 1)]
81+
end
82+
83+
graph[edge_t] = (s, t, ew)
84+
edata = g.edata |> copy
85+
ndata = g.ndata |> copy
86+
ntypes = g.ntypes |> copy
87+
etypes = g.etypes |> copy
88+
num_nodes = g.num_nodes |> copy
89+
num_edges = g.num_edges |> copy
90+
num_edges[edge_t] = length(get(graph, edge_t, ([],[]))[1])
91+
92+
return GNNHeteroGraph(graph,
93+
num_nodes, num_edges, g.num_graphs,
94+
g.graph_indicator,
95+
ndata, edata, g.gdata,
96+
ntypes, etypes)
97+
end
98+
4199
"""
42100
remove_self_loops(g::GNNGraph)
43101

test/GNNGraphs/transform.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,4 +462,30 @@ end
462462
@test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8]
463463
end
464464
end
465+
466+
@testset "add self-loops heterographs" begin
467+
g = rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20))
468+
# Case in which haskey(g.graph, edge_t) passes
469+
g = add_self_loops(g, (:A, :to1, :A))
470+
471+
@test g.num_edges[(:A, :to1, :A)] == 5 + 10
472+
@test g.num_edges[(:A, :to1, :B)] == 20
473+
# This test should not use length(keys(g.num_edges)) since that may be undefined behavior
474+
@test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 2
475+
476+
# Case in which haskey(g.graph, edge_t) fails
477+
g = add_self_loops(g, (:A, :to3, :A))
478+
479+
@test g.num_edges[(:A, :to1, :A)] == 5 + 10
480+
@test g.num_edges[(:A, :to1, :B)] == 20
481+
@test g.num_edges[(:A, :to3, :A)] == 10
482+
@test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 3
483+
484+
# Case with edge weights
485+
g = GNNHeteroGraph(Dict((:A, :to1, :A) => ([1, 2, 3], [3, 2, 1], [2, 2, 2]), (:A, :to2, :B) => ([1, 4, 5], [1, 2, 3])))
486+
n = g.num_nodes[:A]
487+
g = add_self_loops(g, (:A, :to1, :A))
488+
489+
@test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n))
490+
end
465491
end

0 commit comments

Comments
 (0)