Skip to content

Commit 1b8e19b

Browse files
committed
Add requires_threadsafe (#353)
See related discussion: TuringLang/Turing.jl#1726 (comment).
1 parent 57c50f1 commit 1b8e19b

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.17.1"
3+
version = "0.17.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ number of `sampler`.
376376
"""
377377
(model::Model)(args...) = first(evaluate!!(model, args...))
378378

379+
"""
380+
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
381+
382+
Return `true` if evaluation of a model using `context` and `varinfo` should
383+
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
384+
"""
385+
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
386+
return Threads.nthreads() > 1
387+
end
388+
379389
"""
380390
evaluate!!(model::Model[, rng, varinfo, sampler, context])
381391
@@ -388,10 +398,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu
388398
number of `sampler`.
389399
"""
390400
function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
391-
if Threads.nthreads() == 1
392-
return evaluate_threadunsafe!!(model, varinfo, context)
401+
return if use_threadsafe_eval(context, varinfo)
402+
evaluate_threadsafe!!(model, varinfo, context)
393403
else
394-
return evaluate_threadsafe!!(model, varinfo, context)
404+
evaluate_threadunsafe!!(model, varinfo, context)
395405
end
396406
end
397407

0 commit comments

Comments
 (0)