Skip to content

Commit dc7c59a

Browse files
committed
Add docstrings and tests
1 parent d2f64fa commit dc7c59a

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed

src/model.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,42 @@ function (model::Model)(
133133
context::AbstractContext = DefaultContext()
134134
)
135135
if Threads.nthreads() == 1
136-
return evaluate_singlethreaded(rng, model, varinfo, sampler, context)
136+
return evaluate_threadunsafe(rng, model, varinfo, sampler, context)
137137
else
138-
return evaluate_multithreaded(rng, model, varinfo, sampler, context)
138+
return evaluate_threadsafe(rng, model, varinfo, sampler, context)
139139
end
140140
end
141141

142-
function evaluate_singlethreaded(rng, model, varinfo, sampler, context)
142+
"""
143+
evaluate_threadunsafe(rng, model, varinfo, sampler, context)
144+
145+
Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
146+
147+
If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
148+
This method is not exposed and supposed to be used only internally in DynamicPPL.
149+
150+
See also: [`evaluate_threadsafe`](@ref)
151+
"""
152+
function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
143153
resetlogp!(varinfo)
144154
if has_eval_num(sampler)
145155
sampler.state.eval_num += 1
146156
end
147157
return model.f(rng, model, varinfo, sampler, context)
148158
end
149159

150-
function evaluate_multithreaded(rng, model, varinfo, sampler, context)
160+
"""
161+
evaluate_threadsafe(rng, model, varinfo, sampler, context)
162+
163+
Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`.
164+
165+
With the wrapper, Julia's multithreading can be used for assume statements in the `model`
166+
but parallel sampling will lead to undefined behaviour.
167+
This method is not exposed and supposed to be used only internally in DynamicPPL.
168+
169+
See also: [`evaluate_threadunsafe`](@ref)
170+
"""
171+
function evaluate_threadsafe(rng, model, varinfo, sampler, context)
151172
resetlogp!(varinfo)
152173
if has_eval_num(sampler)
153174
sampler.state.eval_num += 1

test/threadsafe.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
x = rand(10_000)
4040

4141
@model function wthreads(x)
42+
global vi_ = _varinfo
4243
x[1] ~ Normal(0, 1)
4344
Threads.@threads for i in 2:length(x)
4445
x[i] ~ Normal(x[i-1], 1)
@@ -48,21 +49,28 @@
4849
vi = VarInfo()
4950
wthreads(x)(vi)
5051
lp_w_threads = getlogp(vi)
52+
if Threads.nthreads() == 1
53+
@test vi_ isa VarInfo
54+
else
55+
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
56+
end
5157

5258
println("With `@threads`:")
5359
println(" default:")
5460
@time wthreads(x)(vi)
5561

56-
# Ensure that we use `ThreadSafeVarInfo`.
62+
# Ensure that we use `ThreadSafeVarInfo` to handle multithreaded assume statements.
63+
DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi,
64+
SampleFromPrior(), DefaultContext())
5765
@test getlogp(vi) lp_w_threads
58-
DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi,
59-
SampleFromPrior(), DefaultContext())
66+
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
6067

61-
println(" evaluate_multithreaded:")
62-
@time DynamicPPL.evaluate_multithreaded(Random.GLOBAL_RNG, wthreads(x), vi,
63-
SampleFromPrior(), DefaultContext())
68+
println(" evaluate_threadsafe:")
69+
@time DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi,
70+
SampleFromPrior(), DefaultContext())
6471

6572
@model function wothreads(x)
73+
global vi_ = _varinfo
6674
x[1] ~ Normal(0, 1)
6775
for i in 2:length(x)
6876
x[i] ~ Normal(x[i-1], 1)
@@ -72,6 +80,11 @@
7280
vi = VarInfo()
7381
wothreads(x)(vi)
7482
lp_wo_threads = getlogp(vi)
83+
if Threads.nthreads() == 1
84+
@test vi_ isa VarInfo
85+
else
86+
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
87+
end
7588

7689
println("Without `@threads`:")
7790
println(" default:")
@@ -80,12 +93,13 @@
8093
@test lp_w_threads lp_wo_threads
8194

8295
# Ensure that we use `VarInfo`.
83-
DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi,
84-
SampleFromPrior(), DefaultContext())
96+
DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi,
97+
SampleFromPrior(), DefaultContext())
8598
@test getlogp(vi) lp_w_threads
99+
@test vi_ isa VarInfo
86100

87-
println(" evaluate_singlethreaded:")
88-
@time DynamicPPL.evaluate_singlethreaded(Random.GLOBAL_RNG, wothreads(x), vi,
89-
SampleFromPrior(), DefaultContext())
101+
println(" evaluate_threadunsafe:")
102+
@time DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi,
103+
SampleFromPrior(), DefaultContext())
90104
end
91105
end

0 commit comments

Comments
 (0)