Skip to content

Commit 3daf120

Browse files
rng instead of seed for rand_graph
1 parent ef22e9a commit 3daf120

File tree

6 files changed

+132
-31
lines changed

6 files changed

+132
-31
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ Manifest.toml
99
.vscode
1010
LocalPreferences.toml
1111
.DS_Store
12-
docs/src/democards/gridtheme.css
12+
docs/src/democards/gridtheme.css
13+
test.jl

GNNGraphs/src/abstracttypes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
2+
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: AbstractVector}
33
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
44
const ADJMAT_T = AbstractMatrix
55
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

GNNGraphs/src/generate.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,30 @@ GNNGraph:
3636
# Each edge has a reverse
3737
julia> edge_index(g)
3838
([1, 3, 3, 4], [3, 4, 1, 3])
39-
4039
```
4140
"""
42-
function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...)
41+
function rand_graph(n::Integer, m::Integer; seed=-1, kws...)
42+
if seed != -1
43+
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph)
44+
rng = MersenneTwister(seed)
45+
else
46+
rng = Random.default_rng()
47+
end
48+
return rand_graph(rng, n, m; kws...)
49+
end
50+
51+
function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true, edge_weight = nothing, kws...)
4352
if bidirected
44-
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
53+
@assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m."
54+
s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false)
55+
s, t = vcat(s, t), vcat(t, s)
56+
if edge_weight !== nothing
57+
edge_weight = vcat(edge_weight, edge_weight)
58+
end
59+
else
60+
s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
4561
end
46-
m2 = bidirected ? m ÷ 2 : m
47-
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...)
62+
return GNNGraph((s, t, edge_weight); kws...)
4863
end
4964

5065
"""
@@ -104,12 +119,6 @@ function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
104119
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
105120
end
106121

107-
function _rand_edges(rng, (n1, n2), m)
108-
idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
109-
s, t = edge_decoding(idx, n1, n2)
110-
val = nothing
111-
return s, t, val
112-
end
113122

114123
"""
115124
rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])

GNNGraphs/src/utils.jl

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,13 @@ end
205205
numnonzeros(a::AbstractSparseMatrix) = nnz(a)
206206
numnonzeros(a::AbstractMatrix) = count(!=(0), a)
207207

208-
# each edge is represented by a number in
209-
# 1:N^2
210-
function edge_encoding(s, t, n; directed = true)
211-
if directed
212-
# directed edges and self-loops allowed
213-
idx = (s .- 1) .* n .+ t
208+
## Map edges into a contiguous range of integers
209+
function edge_encoding(s, t, n; directed = true, self_loops = true)
210+
if directed && self_loops
214211
maxid = n^2
215-
else
216-
# Undirected edges and self-loops allowed
212+
idx = (s .- 1) .* n .+ t
213+
elseif !directed && self_loops
217214
maxid = n * (n + 1) ÷ 2
218-
219215
mask = s .> t
220216
snew = copy(s)
221217
tnew = copy(t)
@@ -228,18 +224,34 @@ function edge_encoding(s, t, n; directed = true)
228224
# = ∑_{i',i'<i} (n - i' + 1) + (j - i + 1)
229225
# = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
230226
idx = @. (s - 1) * (2 * (n + 1) - s) ÷ 2 + (t - s + 1)
227+
elseif directed && !self_loops
228+
@assert all(s .!= t)
229+
maxid = n * (n - 1)
230+
idx = (s .- 1) .* (n - 1) .+ t .- (t .> s)
231+
elseif !directed && !self_loops
232+
@assert all(s .!= t)
233+
maxid = n * (n - 1) ÷ 2
234+
mask = s .> t
235+
snew = copy(s)
236+
tnew = copy(t)
237+
snew[mask] .= t[mask]
238+
tnew[mask] .= s[mask]
239+
s, t = snew, tnew
240+
241+
# idx(s,t) = ∑_{s',1<= s'<s} ∑_{t',s'< t' <=n} 1 + ∑_{t',s<t'<=t} 1
242+
# idx(s,t) = ∑_{s',1<= s'<s} (n-s') + (t-s)
243+
# idx(s,t) = (s-1)n - s*(s-1)/2 + (t-s)
244+
idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s)
231245
end
232246
return idx, maxid
233247
end
234248

235-
# each edge is represented by a number in
236-
# 1:N^2
237-
function edge_decoding(idx, n; directed = true)
238-
if directed
239-
# g = remove_self_loops(g)
249+
# inverse of edge_encoding
250+
function edge_decoding(idx, n; directed = true, self_loops = true)
251+
if directed && self_loops
240252
s = (idx .- 1) n .+ 1
241253
t = (idx .- 1) .% n .+ 1
242-
else
254+
elseif !directed && self_loops
243255
# We replace j=n in
244256
# idx = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
245257
# and obtain
@@ -252,19 +264,52 @@ function edge_decoding(idx, n; directed = true)
252264
s = @. ceil(Int, -sqrt((n + 1 / 2)^2 - 2 * idx) + n + 1 / 2)
253265
t = @. idx - (s - 1) * (2 * (n + 1) - s) ÷ 2 - 1 + s
254266
# t = (idx .- 1) .% n .+ 1
267+
elseif directed && !self_loops
268+
s = (idx .- 1) (n - 1) .+ 1
269+
t = (idx .- 1) .% (n - 1) .+ 1
270+
t = t .+ (t .>= s)
271+
elseif !directed && !self_loops
272+
# Considering t = s + 1 in
273+
# idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s)
274+
# and inverting for s we have
275+
s = @. floor(Int, 1/2 + n - 1/2 * sqrt(9 - 4n + 4n^2 - 8*idx))
276+
# now we can find t
277+
t = @. idx - (s - 1) * n + s * (s - 1) ÷ 2 + s
255278
end
256279
return s, t
257280
end
258281

259-
# each edge is represented by a number in
260-
# 1:n1*n2
282+
# for bipartite graphs
261283
function edge_decoding(idx, n1, n2)
262284
@assert all(1 .<= idx .<= n1 * n2)
263285
s = (idx .- 1) n2 .+ 1
264286
t = (idx .- 1) .% n2 .+ 1
265287
return s, t
266288
end
267289

290+
function _rand_edges(rng, n::Int, m::Int; directed = true, self_loops = true)
291+
idmax = if directed && self_loops
292+
n^2
293+
elseif !directed && self_loops
294+
n * (n + 1) ÷ 2
295+
elseif directed && !self_loops
296+
n * (n - 1)
297+
elseif !directed && !self_loops
298+
n * (n - 1) ÷ 2
299+
end
300+
idx = StatsBase.sample(rng, 1:idmax, m, replace = false)
301+
s, t = edge_decoding(idx, n; directed, self_loops)
302+
val = nothing
303+
return s, t, val
304+
end
305+
306+
function _rand_edges(rng, (n1, n2), m)
307+
idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
308+
s, t = edge_decoding(idx, n1, n2)
309+
val = nothing
310+
return s, t, val
311+
end
312+
268313
binarize(x) = map(>(0), x)
269314

270315
@non_differentiable binarize(x...)

GNNGraphs/test/utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,50 @@
4747
tnew[mask] .= s1[mask]
4848
@test sdec == snew
4949
@test tdec == tnew
50+
51+
@testset "directed=false, self_loops=false" begin
52+
n = 5
53+
edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)]
54+
s = [e[1] for e in edges]
55+
t = [e[2] for e in edges]
56+
g = GNNGraph(s, t)
57+
idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false)
58+
@test idxmax == n * (n - 1) ÷ 2
59+
@test idx == 1:idxmax
60+
61+
snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false)
62+
@test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
63+
@test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5]
64+
end
65+
66+
@testset "directed=false, self_loops=false" begin
67+
n = 5
68+
edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)]
69+
s = [e[1] for e in edges]
70+
t = [e[2] for e in edges]
71+
72+
idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false)
73+
@test idxmax == n * (n - 1) ÷ 2
74+
@test idx == 1:idxmax
75+
76+
snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false)
77+
@test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
78+
@test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5]
79+
end
80+
81+
@testset "directed=true, self_loops=false" begin
82+
n = 5
83+
edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)]
84+
s = [e[1] for e in edges]
85+
t = [e[2] for e in edges]
86+
87+
idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=true, self_loops=false)
88+
@test idxmax == n^2 - n
89+
@test idx == [1, 9, 3, 4, 6, 7, 8, 11, 12, 16]
90+
snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=true, self_loops=false)
91+
@test snew == s
92+
@test tnew == t
93+
end
5094
end
5195

5296
@testset "color_refinment" begin

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ tests = [
3535
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
3636

3737
# @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse)
38-
for graph_type in (:coo, :dense, :sparse)
38+
# for graph_type in (:coo, :dense, :sparse)
39+
for graph_type in (:dense,)
40+
3941
@info "Testing graph format :$graph_type"
4042
global GRAPH_T = graph_type
4143
global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)

0 commit comments

Comments
 (0)