Skip to content

Commit b08084f

Browse files
undirected edge encoding
1 parent 22ab06c commit b08084f

File tree

3 files changed

+67
-30
lines changed

3 files changed

+67
-30
lines changed

src/GNNGraphs/transform.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -393,36 +393,6 @@ function rand_edge_split(g::GNNGraph, frac)
393393
end
394394

395395

396-
# each edge is represented by a number in
397-
# 1:N^2
398-
function edge_encoding(s, t, n; directed=true)
399-
if directed
400-
# directed edges and self-loops allowed
401-
idx = (s .- 1) .* n .+ t
402-
maxid = n^2
403-
else
404-
# undirected edges and self-loops allowed
405-
maxid = n * (n - 1) ÷ 2
406-
mask = s .<= t
407-
s1, t1 = s[mask], t[mask]
408-
t2, s2 = s[.!mask], t[.!mask]
409-
s, t = [s1; s2], [t1; t2]
410-
offset1 = (n .* 0:n-1) .- cumsum(0:n-1)
411-
offset2 = 0:n-1
412-
idx = offset1[s] .+ (t .- offset2)
413-
end
414-
return idx, maxid
415-
end
416-
417-
# each edge is represented by a number in
418-
# 1:N^2
419-
function edge_decoding(idx, n)
420-
# g = remove_self_loops(g)
421-
s = (idx .- 1) n .+ 1
422-
t = (idx .- 1) .% n .+ 1
423-
return s, t
424-
end
425-
426396
# """
427397
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
428398
# """

src/GNNGraphs/utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,47 @@ end
7070
ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1)
7171
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
7272
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
73+
74+
75+
# each edge is represented by a number in
76+
# 1:N^2
77+
function edge_encoding(s, t, n; directed=true)
78+
if directed
79+
# directed edges and self-loops allowed
80+
idx = (s .- 1) .* n .+ t
81+
maxid = n^2
82+
else
83+
# Undirected edges and self-loops allowed
84+
# In this encoding, each edge has 2 possible encodings (also the self-loops).
85+
# We return the canonical one given by the upper triangular adj matrix
86+
maxid = n * (n + 1) ÷ 2
87+
mask = s .> t
88+
# s1, t1 = s[mask], t[mask]
89+
# t2, s2 = s[.!mask], t[.!mask]
90+
snew = copy(s)
91+
tnew = copy(t)
92+
snew[mask] .= t[mask]
93+
tnew[mask] .= s[mask]
94+
s, t = snew, tnew
95+
96+
# idx = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + ∑_{j',i<=j'<=j} 1
97+
# = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + j - i + 1
98+
# = ∑_{i',i'<i} (n - i' + 1) + (j - i + 1)
99+
# = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
100+
idx = @. (s-1)*(2*(n+1)-s)÷2 + (t-s+1)
101+
end
102+
return idx, maxid
103+
end
104+
105+
# each edge is represented by a number in
106+
# 1:N^2
107+
function edge_decoding(idx, n; directed=true)
108+
# g = remove_self_loops(g)
109+
s = (idx .- 1) n .+ 1
110+
t = (idx .- 1) .% n .+ 1
111+
return s, t
112+
end
113+
114+
@non_differentiable edge_encoding(x...)
115+
@non_differentiable edge_decoding(x...)
116+

test/GNNGraphs/utils.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@testset "Utils" begin
2+
@testset "edge encoding/decoding" begin
3+
# not is_bidirected
4+
n = 5
5+
s = [1,1,2,3,3,4,5]
6+
t = [1,3,1,1,2,5,5]
7+
8+
# directed=true
9+
idx, maxid = GNNGraphs.edge_encoding(s, t, n)
10+
@test maxid == n^2
11+
@test idx == [1, 3, 6, 11, 12, 20, 25]
12+
13+
sdec, tdec = GNNGraphs.edge_decoding(idx, n)
14+
@test sdec == s
15+
@test tdec == t
16+
17+
18+
# directed=false
19+
idx, maxid = GNNGraphs.edge_encoding(s, t, n, directed=false)
20+
@test maxid == n * (n+1)÷2
21+
@test idx == [1, 3, 2, 3, 7, 14, 15]
22+
end
23+
end

0 commit comments

Comments
 (0)