Skip to content

Commit d2abe2e

Browse files
fix addselfloops
1 parent 7909fdc commit d2abe2e

File tree

3 files changed

+7
-18
lines changed

3 files changed

+7
-18
lines changed

src/featuredgraph.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -353,23 +353,18 @@ _eigmax(A) = KrylovKit.eigsolve(Symmetric(A), 1, :LR)[1][1] # also eigs(A, x0, n
353353
# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5
354354

355355
"""
356-
add_self_loops(fg::FeaturedGraph; add_to_existing=true)
356+
add_self_loops(fg::FeaturedGraph)
357357
358358
Return a featured graph with the same features as `fg`
359359
but also adding edges connecting the nodes to themselves.
360360
361-
If `add_to_existing=true`, nodes with already existing
361+
Nodes with already existing
362362
self-loops will obtain a second self-loop.
363363
"""
364-
function add_self_loops(fg::FeaturedGraph{<:COO_T}; add_to_existing=true)
364+
function add_self_loops(fg::FeaturedGraph{<:COO_T})
365365
s, t = edge_index(fg)
366366
@assert edge_feature(fg) === nothing
367367
@assert edge_weight(fg) === nothing
368-
if !add_to_existing
369-
mask_old_loops = s .!= t
370-
s = s[mask_old_loops]
371-
t = t[mask_old_loops]
372-
end
373368
n = fg.num_nodes
374369
nodes = convert(typeof(s), [1:n;])
375370
s = [s; nodes]
@@ -382,14 +377,8 @@ end
382377
function add_self_loops(fg::FeaturedGraph{<:ADJMAT_T}; add_to_existing=true)
383378
A = graph(fg)
384379
@assert edge_feature(fg) === nothing
385-
if add_to_existing
386-
nold = 0
387-
A += I
388-
else
389-
nold = sum(Diagonal(A)) |> Int
390-
A += I - Diagonal(A)
391-
end
392-
num_edges = fg.num_edges - nold + fg.num_nodes
380+
A += I
381+
num_edges = fg.num_edges + fg.num_nodes
393382
FeaturedGraph(A, fg.num_nodes, num_edges,
394383
node_feature(fg), edge_feature(fg), global_feature(fg))
395384
end

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ message(l::GCNConv, xi, xj) = xj
4343
update(l::GCNConv, m, x) = m
4444

4545
function (l::GCNConv)(fg::FeaturedGraph, x::CuMatrix{T}) where T
46-
fg = add_self_loops(fg; add_to_existing=true)
46+
fg = add_self_loops(fg)
4747
c = 1 ./ sqrt.(degree(fg, T, dir=:in))
4848
x = x .* c'
4949
_, x = propagate(l, fg, nothing, x, nothing, +)

test/featured_graph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
0 0 1 0
9090
0 0 0 1
9191
1 0 0 0]
92-
A2 = [1 1 0 0
92+
A2 = [2 1 0 0
9393
0 1 1 0
9494
0 0 1 1
9595
1 0 0 1]

0 commit comments

Comments
 (0)