Skip to content

Commit 9c7143b

Browse files
committed
Replace GLOBAL_RNG with default_rng() (#468)
I just noticed that we still use `GLOBAL_RNG` in a few places instead of `default_rng()` (see e.g. JuliaStats/Distributions.jl#1679 and references therein why the newer default_rng() has advantages over GLOBAL_RNG).
1 parent ee23853 commit 9c7143b

File tree

8 files changed

+23
-23
lines changed

8 files changed

+23
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.22.1"
3+
version = "0.22.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/contexts.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
121121
# Contexts
122122
"""
123123
SamplingContext(
124-
[rng::Random.AbstractRNG=Random.GLOBAL_RNG],
124+
[rng::Random.AbstractRNG=Random.default_rng()],
125125
[sampler::AbstractSampler=SampleFromPrior()],
126126
[context::AbstractContext=DefaultContext()],
127127
)
@@ -138,23 +138,23 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
138138
end
139139

140140
function SamplingContext(
141-
rng::Random.AbstractRNG=Random.GLOBAL_RNG, sampler::AbstractSampler=SampleFromPrior()
141+
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
142142
)
143143
return SamplingContext(rng, sampler, DefaultContext())
144144
end
145145

146146
function SamplingContext(
147147
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
148148
)
149-
return SamplingContext(Random.GLOBAL_RNG, sampler, context)
149+
return SamplingContext(Random.default_rng(), sampler, context)
150150
end
151151

152152
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
153153
return SamplingContext(rng, SampleFromPrior(), context)
154154
end
155155

156156
function SamplingContext(context::AbstractContext)
157-
return SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), context)
157+
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
158158
end
159159

160160
NodeTrait(context::SamplingContext) = IsParent()

src/model.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ function AbstractPPL.evaluate!!(model::Model, context::AbstractContext)
520520
end
521521

522522
function AbstractPPL.evaluate!!(model::Model, args...)
523-
return evaluate!!(model, Random.GLOBAL_RNG, args...)
523+
return evaluate!!(model, Random.default_rng(), args...)
524524
end
525525

526526
# without VarInfo
@@ -626,7 +626,7 @@ Base.nameof(model::Model) = Symbol(model.f)
626626
Base.nameof(model::Model{<:Function}) = nameof(model.f)
627627

628628
"""
629-
rand([rng=Random.GLOBAL_RNG], [T=NamedTuple], model::Model)
629+
rand([rng=Random.default_rng()], [T=NamedTuple], model::Model)
630630
631631
Generate a sample of type `T` from the prior distribution of the `model`.
632632
"""
@@ -643,8 +643,8 @@ end
643643

644644
# Default RNG and type
645645
Base.rand(rng::Random.AbstractRNG, model::Model) = rand(rng, NamedTuple, model)
646-
Base.rand(::Type{T}, model::Model) where {T} = rand(Random.GLOBAL_RNG, T, model)
647-
Base.rand(model::Model) = rand(Random.GLOBAL_RNG, NamedTuple, model)
646+
Base.rand(::Type{T}, model::Model) where {T} = rand(Random.default_rng(), T, model)
647+
Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model)
648648

649649
"""
650650
logjoint(model::Model, varinfo::AbstractVarInfo)

src/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function VarInfo(
135135
model(rng, varinfo, sampler, context)
136136
return TypedVarInfo(varinfo)
137137
end
138-
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
138+
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
139139

140140
unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)
141141

test/contexts.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,16 @@ end
240240
end
241241

242242
@testset "SamplingContext" begin
243-
context = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext())
243+
context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext())
244244
@test context isa SamplingContext
245245

246246
# convenience constructors
247247
@test SamplingContext() == context
248-
@test SamplingContext(Random.GLOBAL_RNG) == context
248+
@test SamplingContext(Random.default_rng()) == context
249249
@test SamplingContext(SampleFromPrior()) == context
250250
@test SamplingContext(DefaultContext()) == context
251-
@test SamplingContext(Random.GLOBAL_RNG, SampleFromPrior()) == context
252-
@test SamplingContext(Random.GLOBAL_RNG, DefaultContext()) == context
251+
@test SamplingContext(Random.default_rng(), SampleFromPrior()) == context
252+
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
253253
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
254254
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
255255
end

test/model.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ end
4545
for i in 1:10
4646
Random.seed!(100 + i)
4747
vi = VarInfo()
48-
model(Random.GLOBAL_RNG, vi, sampler)
48+
model(Random.default_rng(), vi, sampler)
4949
vals = DynamicPPL.getall(vi)
5050

5151
Random.seed!(100 + i)
5252
vi = VarInfo()
53-
model(Random.GLOBAL_RNG, vi, sampler)
53+
model(Random.default_rng(), vi, sampler)
5454
@test DynamicPPL.getall(vi) == vals
5555
end
5656
end
@@ -63,7 +63,7 @@ end
6363
s, m = model()
6464

6565
Random.seed!(100)
66-
@test model(Random.GLOBAL_RNG) == (s, m)
66+
@test model(Random.default_rng()) == (s, m)
6767
end
6868

6969
@testset "nameof" begin

test/threadsafe.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
DynamicPPL.evaluate_threadsafe!!(
6464
wthreads(x),
6565
vi,
66-
SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()),
66+
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
6767
)
6868
@test getlogp(vi) lp_w_threads
6969
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
@@ -72,7 +72,7 @@
7272
@time DynamicPPL.evaluate_threadsafe!!(
7373
wthreads(x),
7474
vi,
75-
SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()),
75+
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
7676
)
7777

7878
@model function wothreads(x)
@@ -102,7 +102,7 @@
102102
DynamicPPL.evaluate_threadunsafe!!(
103103
wothreads(x),
104104
vi,
105-
SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()),
105+
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
106106
)
107107
@test getlogp(vi) lp_w_threads
108108
@test vi_ isa VarInfo
@@ -111,7 +111,7 @@
111111
@time DynamicPPL.evaluate_threadunsafe!!(
112112
wothreads(x),
113113
vi,
114-
SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()),
114+
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
115115
)
116116
end
117117
end

test/turing/varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@
285285
#= g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f)
286286
vi = VarInfo()
287287
g_demo_f(vi, SampleFromPrior())
288-
_, state = @inferred AbstractMCMC.step(Random.GLOBAL_RNG, g_demo_f, g)
288+
_, state = @inferred AbstractMCMC.step(Random.default_rng(), g_demo_f, g)
289289
pg, hmc = state.states
290290
@test pg isa TypedVarInfo
291291
@test hmc isa Turing.Inference.HMCState
@@ -302,7 +302,7 @@
302302
vi = empty!!(TypedVarInfo(vi))
303303
@inferred g_demo_f(vi, SampleFromPrior())
304304
pg.state.vi = vi
305-
step!(Random.GLOBAL_RNG, g_demo_f, pg, 1)
305+
step!(Random.default_rng(), g_demo_f, pg, 1)
306306
vi = pg.state.vi
307307
@inferred g_demo_f(vi, hmc)
308308
@test vi.metadata.x.gids[1] == Set([pg.selector])

0 commit comments

Comments
 (0)