Skip to content

Commit a4f6706

Browse files
committed
Allow to provide RNG only in model evaluation (#172)
Addresses TuringLang/Turing.jl#1421 (comment).
1 parent 161f820 commit a4f6706

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.9.2"
3+
version = "0.9.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,26 +76,35 @@ Sample from the `model` using the `sampler` with random number generator `rng` a
7676
The method resets the log joint probability of `varinfo` and increases the evaluation
7777
number of `sampler`.
7878
"""
79-
function (model::Model)(args...)
80-
return model(VarInfo(), args...)
81-
end
82-
83-
function (model::Model)(varinfo::AbstractVarInfo, args...)
84-
return model(Random.GLOBAL_RNG, varinfo, args...)
85-
end
86-
8779
function (model::Model)(
8880
rng::Random.AbstractRNG,
89-
varinfo::AbstractVarInfo,
81+
varinfo::AbstractVarInfo = VarInfo(),
9082
sampler::AbstractSampler = SampleFromPrior(),
91-
context::AbstractContext = DefaultContext()
83+
context::AbstractContext = DefaultContext(),
9284
)
9385
if Threads.nthreads() == 1
9486
return evaluate_threadunsafe(rng, model, varinfo, sampler, context)
9587
else
9688
return evaluate_threadsafe(rng, model, varinfo, sampler, context)
9789
end
9890
end
91+
function (model::Model)(args...)
92+
return model(Random.GLOBAL_RNG, args...)
93+
end
94+
95+
# without VarInfo
96+
function (model::Model)(
97+
rng::Random.AbstractRNG,
98+
sampler::AbstractSampler,
99+
args...,
100+
)
101+
return model(rng, VarInfo(), sampler, args...)
102+
end
103+
104+
# without VarInfo and without AbstractSampler
105+
function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext)
106+
return model(rng, VarInfo(), SampleFromPrior(), context)
107+
end
99108

100109
"""
101110
evaluate_threadunsafe(rng, model, varinfo, sampler, context)

test/model.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ Random.seed!(1234)
4545
end
4646
end
4747

48+
@testset "defaults without VarInfo, Sampler, and Context" begin
49+
model = gdemo_default
50+
51+
Random.seed!(100)
52+
s, m = model()
53+
54+
Random.seed!(100)
55+
@test model(Random.GLOBAL_RNG) == (s, m)
56+
end
57+
4858
@testset "setval! & generated_quantities" begin
4959
@model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV}
5060
m = TV(undef, 2)

0 commit comments

Comments
 (0)