Skip to content

Commit a9ba4d3

Browse files
authored
Replace GLOBAL_RNG with default_rng() (#104)
* Replace `GLOBAL_RNG` with `default_rng()` * Update utils.jl * Update sample.jl
1 parent 5b75d15 commit a9ba4d3

File tree

6 files changed

+14
-16
lines changed

6 files changed

+14
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.1.2"
6+
version = "4.1.3"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function setprogress!(progress::Bool)
1313
end
1414

1515
function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...)
16-
return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, arg; kwargs...)
16+
return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...)
1717
end
1818

1919
"""
@@ -63,7 +63,7 @@ function StatsBase.sample(
6363
kwargs...,
6464
)
6565
return StatsBase.sample(
66-
Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; kwargs...
66+
Random.default_rng(), model, sampler, parallel, N, nchains; kwargs...
6767
)
6868
end
6969

src/stepper.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()
4242
Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown()
4343

4444
function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...)
45-
return steps(Random.GLOBAL_RNG, model, sampler; kwargs...)
45+
return steps(Random.default_rng(), model, sampler; kwargs...)
4646
end
4747

4848
"""

src/transducer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <:
77
end
88

99
function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
10-
return Sample(Random.GLOBAL_RNG, model, sampler; kwargs...)
10+
return Sample(Random.default_rng(), model, sampler; kwargs...)
1111
end
1212

1313
"""

test/sample.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
Random.seed!(1234)
77
N = 1_000
8-
chain = sample(MyModel(), MySampler(), N; sleepy=true, loggers=true)
8+
chain = sample(MyModel(), MySampler(), N; loggers=true)
99

1010
@test length(LOGGERS) == 1
1111
logger = first(LOGGERS)
@@ -42,7 +42,7 @@
4242

4343
logger = JunoProgressLogger()
4444
Logging.with_logger(logger) do
45-
sample(MyModel(), MySampler(), N; sleepy=true, loggers=true)
45+
sample(MyModel(), MySampler(), N; loggers=true)
4646
end
4747

4848
@test length(LOGGERS) == 1
@@ -60,7 +60,7 @@
6060

6161
Random.seed!(1234)
6262
N = 10
63-
sample(MyModel(), MySampler(), N; sleepy=true, loggers=true)
63+
sample(MyModel(), MySampler(), N; loggers=true)
6464

6565
@test length(LOGGERS) == 1
6666
logger = first(LOGGERS)
@@ -82,7 +82,7 @@
8282

8383
logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1))
8484
Logging.with_logger(logger) do
85-
sample(MyModel(), MySampler(), N; sleepy=true, loggers=true)
85+
sample(MyModel(), MySampler(), N; loggers=true)
8686
end
8787

8888
@test length(LOGGERS) == 1
@@ -92,7 +92,7 @@
9292

9393
@testset "Suppress output" begin
9494
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
95-
sample(MyModel(), MySampler(), 100; progress=false, sleepy=true)
95+
sample(MyModel(), MySampler(), 100; progress=false)
9696
end
9797
@test all(l.level > Logging.LogLevel(-1) for l in logs)
9898

@@ -103,7 +103,7 @@
103103
@test !AbstractMCMC.PROGRESS[]
104104

105105
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
106-
sample(MyModel(), MySampler(), 100; sleepy=true)
106+
sample(MyModel(), MySampler(), 100)
107107
end
108108
@test all(l.level > Logging.LogLevel(-1) for l in logs)
109109

@@ -462,8 +462,8 @@
462462
end
463463

464464
@testset "Chain constructors" begin
465-
chain1 = sample(MyModel(), MySampler(), 100; sleepy=true)
466-
chain2 = sample(MyModel(), MySampler(), 100; sleepy=true, chain_type=MyChain)
465+
chain1 = sample(MyModel(), MySampler(), 100)
466+
chain2 = sample(MyModel(), MySampler(), 100; chain_type=MyChain)
467467

468468
@test chain1 isa Vector{<:MySample}
469469
@test chain2 isa MyChain

test/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ function AbstractMCMC.step(
2121
model::MyModel,
2222
sampler::MySampler,
2323
state::Union{Nothing,Integer}=nothing;
24-
sleepy=false,
2524
loggers=false,
2625
init_params=nothing,
2726
kwargs...,
@@ -34,7 +33,6 @@ function AbstractMCMC.step(
3433
end
3534

3635
loggers && push!(LOGGERS, Logging.current_logger())
37-
sleepy && sleep(0.001)
3836

3937
_state = state === nothing ? 1 : state + 1
4038

@@ -72,7 +70,7 @@ end
7270

7371
# Set a default convergence function.
7472
function AbstractMCMC.sample(model, sampler::MySampler; kwargs...)
75-
return sample(Random.GLOBAL_RNG, model, sampler, isdone; kwargs...)
73+
return sample(Random.default_rng(), model, sampler, isdone; kwargs...)
7674
end
7775

7876
function AbstractMCMC.chainscat(

0 commit comments

Comments
 (0)