Skip to content

Commit 2aab80f

Browse files
committed
store RNG instead of seed
allows to customize `DEFAULT_RNG[]` on a global basis to use stable RNGs
1 parent 6574223 commit 2aab80f

File tree

6 files changed

+72
-27
lines changed

6 files changed

+72
-27
lines changed

Project.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,28 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1111

12+
[weakdeps]
13+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
14+
15+
[extensions]
16+
NetworkLayoutGraphsExt = "Graphs"
17+
1218
[compat]
1319
GeometryBasics = "0.4"
1420
Graphs = "1"
1521
Requires = "1"
22+
StableRNGs = "1.0.2"
1623
StaticArrays = "1"
1724
julia = "1.6"
1825

19-
[extensions]
20-
NetworkLayoutGraphsExt = "Graphs"
21-
2226
[extras]
2327
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
2428
GeometryTypes = "4d00f742-c7ba-57c2-abde-4428a4b178cb"
2529
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2630
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2731
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
32+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2833
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2934

3035
[targets]
31-
test = ["Graphs", "Test", "DelimitedFiles", "GeometryTypes", "SparseArrays"]
32-
33-
[weakdeps]
34-
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
36+
test = ["Graphs", "Test", "DelimitedFiles", "GeometryTypes", "SparseArrays", "StableRNGs"]

src/NetworkLayout.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ using StaticArrays
77

88
export LayoutIterator
99

10+
"""
11+
Default RNG for layouts.
12+
"""
13+
const DEFAULT_RNG = Ref{DataType}(MersenneTwister)
14+
1015
"""
1116
AbstractLayout{Dim,Ptype}
1217

src/sfdp.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,40 @@ the nodes.
3333
- `(true, false, false)` : only pin certain coordinates
3434
3535
- `seed=1`: Seed for random initial positions.
36+
- `rng=DEFAULT_RNG[](seed)`
37+
38+
Create rng based on seed. Defaults to `MersenneTwister`, can be specified
39+
by overwriting `DEFAULT_RNG[]`
3640
"""
37-
@addcall struct SFDP{Dim,Ptype,T<:AbstractFloat} <: IterativeLayout{Dim,Ptype}
41+
@addcall struct SFDP{Dim,Ptype,T<:AbstractFloat,RNG} <: IterativeLayout{Dim,Ptype}
3842
tol::T
3943
C::T
4044
K::T
4145
iterations::Int
4246
initialpos::Dict{Int,Point{Dim,Ptype}}
4347
pin::Dict{Int,SVector{Dim,Bool}}
44-
seed::UInt
48+
rng::RNG
4549
end
4650

4751
# TODO: check SFDP default parameters
4852
function SFDP(; dim=2, Ptype=Float64,
4953
tol=1.0, C=0.2, K=1.0,
5054
iterations=100,
5155
initialpos=[], pin=[],
52-
seed=1)
56+
seed=1, rng=DEFAULT_RNG[](seed))
5357
if !isempty(initialpos)
5458
dim, Ptype = infer_pointtype(initialpos)
5559
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
5660
end
5761
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
5862

59-
return SFDP{dim,Ptype,typeof(tol)}(tol, C, K, iterations, _initialpos, _pin, seed)
63+
return SFDP{dim,Ptype,typeof(tol),typeof(rng)}(tol, C, K, iterations, _initialpos, _pin, rng)
6064
end
6165

62-
function Base.iterate(iter::LayoutIterator{SFDP{Dim,Ptype,T}}) where {Dim,Ptype,T}
66+
function Base.iterate(iter::LayoutIterator{<:SFDP{Dim,Ptype,T}}) where {Dim,Ptype,T}
6367
algo, adj_matrix = iter.algorithm, iter.adj_matrix
6468
N = size(adj_matrix, 1)
65-
rng = MersenneTwister(algo.seed)
69+
rng = copy(algo.rng)
6670
startpos = [2 .* rand(rng, Point{Dim,Ptype}) .- 1 for _ in 1:N]
6771

6872
for (k, v) in algo.initialpos
@@ -106,8 +110,8 @@ function Base.iterate(iter::LayoutIterator{<:SFDP}, state)
106110
if any(isnan, force)
107111
# if two points are at the exact same location
108112
# use random force in any direction
109-
rng = MersenneTwister(algo.seed + i)
110-
force = randn(rng, Ftype)
113+
rng = copy(algo.rng)
114+
force += randn(rng, Ftype)
111115
end
112116
mask = (!).(pin[i]) # where pin=true mask will multiply with 0
113117
locs[i] = locs[i] .+ (step .* (force ./ norm(force))) .* mask

src/spring.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,37 @@ the nodes.
3535
- `(true, false, false)` : only pin certain coordinates
3636
3737
- `seed=1`: Seed for random initial positions.
38+
- `rng=DEFAULT_RNG[](seed)`
39+
40+
Create rng based on seed. Defaults to `MersenneTwister`, can be specified
41+
by overwriting `DEFAULT_RNG[]`
3842
"""
39-
@addcall struct Spring{Dim,Ptype} <: IterativeLayout{Dim,Ptype}
43+
@addcall struct Spring{Dim,Ptype,RNG} <: IterativeLayout{Dim,Ptype}
4044
C::Float64
4145
iterations::Int
4246
initialtemp::Float64
4347
initialpos::Dict{Int,Point{Dim,Ptype}}
4448
pin::Dict{Int,SVector{Dim,Bool}}
45-
seed::UInt
49+
rng::RNG
4650
end
4751

4852
function Spring(; dim=2, Ptype=Float64,
4953
C=2.0, iterations=100, initialtemp=2.0,
5054
initialpos=[], pin=[],
51-
seed=1)
55+
seed=1, rng=DEFAULT_RNG[](seed))
5256
if !isempty(initialpos)
5357
dim, Ptype = infer_pointtype(initialpos)
5458
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
5559
end
5660
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
5761

58-
return Spring{dim,Ptype}(C, iterations, initialtemp, _initialpos, _pin, seed)
62+
return Spring{dim,Ptype,typeof(rng)}(C, iterations, initialtemp, _initialpos, _pin, rng)
5963
end
6064

6165
function Base.iterate(iter::LayoutIterator{<:Spring{Dim,Ptype}}) where {Dim,Ptype}
6266
algo, adj_matrix = iter.algorithm, iter.adj_matrix
6367
N = size(adj_matrix, 1)
64-
rng = MersenneTwister(algo.seed)
68+
rng = copy(algo.rng)
6569
startpos = [2 .* rand(rng, Point{Dim,Ptype}) .- 1 for _ in 1:N]
6670

6771
for (k, v) in algo.initialpos
@@ -110,7 +114,7 @@ function Base.iterate(iter::LayoutIterator{<:Spring}, state)
110114
else
111115
# if two points are at the exact same location
112116
# use random force in any direction
113-
rng = MersenneTwister(algo.seed + i)
117+
rng = copy(algo.rng)
114118
force_vec += randn(rng, Ftype)
115119
end
116120

src/stress.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@ The main equation to solve is (8) in Gansner, Koren and North (2005,
5353
- `(true, false, false)` : only pin certain coordinates
5454
5555
- `seed=1`: Seed for random initial positions.
56+
- `rng=DEFAULT_RNG[](seed)`
57+
58+
Create rng based on seed. Defaults to `MersenneTwister`, can be specified
59+
by overwriting `DEFAULT_RNG[]`
5660
"""
57-
@addcall struct Stress{Dim,Ptype,IT<:Union{Symbol,Int},FT<:AbstractFloat,M<:AbstractMatrix} <:
61+
@addcall struct Stress{Dim,Ptype,IT<:Union{Symbol,Int},FT<:AbstractFloat,M<:AbstractMatrix,RNG} <:
5862
IterativeLayout{Dim,Ptype}
5963
iterations::IT
6064
abstols::FT
@@ -63,7 +67,7 @@ The main equation to solve is (8) in Gansner, Koren and North (2005,
6367
weights::M
6468
initialpos::Dict{Int,Point{Dim,Ptype}}
6569
pin::Dict{Int,SVector{Dim,Bool}}
66-
seed::UInt
70+
rng::RNG
6771
end
6872

6973
function Stress(; dim=2,
@@ -74,23 +78,23 @@ function Stress(; dim=2,
7478
abstolx=10e-6,
7579
weights=Array{Float64}(undef, 0, 0),
7680
initialpos=[], pin=[],
77-
seed=1)
81+
seed=1, rng=DEFAULT_RNG[](seed))
7882
if !isempty(initialpos)
7983
dim, Ptype = infer_pointtype(initialpos)
8084
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
8185
end
8286

8387
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
8488

85-
IT, FT, WT = typeof(iterations), typeof(abstols), typeof(weights)
86-
Stress{dim,Ptype,IT,FT,WT}(iterations, abstols, reltols, abstolx, weights, _initialpos, _pin, seed)
89+
IT, FT, WT, RNG = typeof(iterations), typeof(abstols), typeof(weights), typeof(rng)
90+
Stress{dim,Ptype,IT,FT,WT,RNG}(iterations, abstols, reltols, abstolx, weights, _initialpos, _pin, rng)
8791
end
8892

8993
function Base.iterate(iter::LayoutIterator{<:Stress{Dim,Ptype,IT,FT}}) where {Dim,Ptype,IT,FT}
9094
algo, δ = iter.algorithm, iter.adj_matrix
9195
N = size(δ, 1)
9296
M = length(algo.initialpos)
93-
rng = MersenneTwister(algo.seed)
97+
rng = copy(algo.rng)
9498
startpos = randn(rng, Point{Dim,Ptype}, N)
9599

96100
for (k, v) in algo.initialpos

test/runtests.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ using GeometryBasics
44
using DelimitedFiles: readdlm
55
using SparseArrays: sparse
66
using StaticArrays
7+
using StableRNGs
78
using Test
9+
using Random
810

911
function jagmesh()
1012
jagmesh_path = joinpath(dirname(@__FILE__), "jagmesh1.mtx")
@@ -63,6 +65,14 @@ jagmesh_adj = jagmesh()
6365
positions = @time SFDP(; dim=3, Ptype=Float32, tol=0.1, K=1)(adj_matrix)
6466
@test typeof(positions) == Vector{Point3f}
6567
@test positions == sfdp(adj_matrix; dim=3, Ptype=Float32, tol=0.1, K=1)
68+
69+
NetworkLayout.DEFAULT_RNG[] = StableRNG
70+
l = SFDP()
71+
@test l.rng isa StableRNG
72+
li = LayoutIterator(l, g)
73+
p1, p2 = iterate(li)[1], iterate(li)[1]
74+
@test p1 == p2
75+
NetworkLayout.DEFAULT_RNG[] = MersenneTwister
6676
end
6777
end
6878

@@ -112,6 +122,14 @@ jagmesh_adj = jagmesh()
112122
positions = @time Stress(; iterations=10, dim=3, Ptype=Float32)(adj_matrix)
113123
@test typeof(positions) == Vector{Point3f}
114124
@test positions == stress(adj_matrix; iterations=10, dim=3, Ptype=Float32)
125+
126+
NetworkLayout.DEFAULT_RNG[] = StableRNG
127+
l = Stress()
128+
@test l.rng isa StableRNG
129+
li = LayoutIterator(l, g)
130+
p1, p2 = iterate(li)[1], iterate(li)[1]
131+
@test p1 == p2
132+
NetworkLayout.DEFAULT_RNG[] = MersenneTwister
115133
end
116134

117135
@testset "test pairwise_distance" begin
@@ -170,6 +188,14 @@ jagmesh_adj = jagmesh()
170188
@test typeof(positions) == Vector{Point3f}
171189
@test positions ==
172190
spring(adj_matrix; C=2.0, iterations=100, initialtemp=2.0, Ptype=Float32, dim=3)
191+
192+
NetworkLayout.DEFAULT_RNG[] = StableRNG
193+
l = Spring()
194+
@test l.rng isa StableRNG
195+
li = LayoutIterator(l, g)
196+
p1, p2 = iterate(li)[1], iterate(li)[1]
197+
@test p1 == p2
198+
NetworkLayout.DEFAULT_RNG[] = MersenneTwister
173199
end
174200

175201
@testset "test single node graph" begin

0 commit comments

Comments
 (0)