Skip to content

Commit df4f1a0

Browse files
directed edge decoding
1 parent b08084f commit df4f1a0

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

src/GNNGraphs/utils.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,17 @@ function edge_encoding(s, t, n; directed=true)
8181
maxid = n^2
8282
else
8383
# Undirected edges and self-loops allowed
84-
# In this encoding, each edge has 2 possible encodings (also the self-loops).
85-
# We return the canonical one given by the upper triangular adj matrix
8684
maxid = n * (n + 1) ÷ 2
85+
8786
mask = s .> t
88-
# s1, t1 = s[mask], t[mask]
89-
# t2, s2 = s[.!mask], t[.!mask]
9087
snew = copy(s)
9188
tnew = copy(t)
9289
snew[mask] .= t[mask]
9390
tnew[mask] .= s[mask]
9491
s, t = snew, tnew
95-
92+
9693
# idx = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + ∑_{j',i<=j'<=j} 1
97-
# = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + j - i + 1
94+
# = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + (j - i + 1)
9895
# = ∑_{i',i'<i} (n - i' + 1) + (j - i + 1)
9996
# = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
10097
idx = @. (s-1)*(2*(n+1)-s)÷2 + (t-s+1)
@@ -105,9 +102,24 @@ end
105102
# each edge is represented by a number in
106103
# 1:N^2
107104
function edge_decoding(idx, n; directed=true)
108-
# g = remove_self_loops(g)
109-
s = (idx .- 1) n .+ 1
110-
t = (idx .- 1) .% n .+ 1
105+
if directed
106+
# g = remove_self_loops(g)
107+
s = (idx .- 1) n .+ 1
108+
t = (idx .- 1) .% n .+ 1
109+
else
110+
# We replace j=n in
111+
# idx = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
112+
# and obtain
113+
# idx = (i - 1)*(2*(n+1)-i)÷2 + (n - i + 1)
114+
115+
# OR We replace j=i and obtain??
116+
# idx = (i - 1)*(2*(n+1)-i)÷2 + 1
117+
118+
# inverting we have
119+
s = @. ceil(Int, -sqrt((n + 1/2)^2 - 2*idx) + n + 1/2)
120+
t = @. idx - (s-1)*(2*(n+1)-s)÷2 - 1 + s
121+
# t = (idx .- 1) .% n .+ 1
122+
end
111123
return s, t
112124
end
113125

test/GNNGraphs/utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414
@test sdec == s
1515
@test tdec == t
1616

17-
1817
# directed=false
1918
idx, maxid = GNNGraphs.edge_encoding(s, t, n, directed=false)
20-
@test maxid == n * (n+1)÷2
19+
@test maxid == n*(n+1)÷2
2120
@test idx == [1, 3, 2, 3, 7, 14, 15]
21+
22+
mask = s .> t
23+
snew = copy(s)
24+
tnew = copy(t)
25+
snew[mask] .= t[mask]
26+
tnew[mask] .= s[mask]
27+
sdec, tdec = GNNGraphs.edge_decoding(idx, n, directed=false)
28+
@test sdec == snew
29+
@test tdec == tnew
2230
end
2331
end

0 commit comments

Comments
 (0)