Skip to content

Commit 54f1b7b

Browse files
committed
Make threadsafe evaluation opt-in
1 parent 3cd8d34 commit 54f1b7b

File tree

9 files changed

+233
-174
lines changed

9 files changed

+233
-174
lines changed

HISTORY.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,35 @@
44

55
### Breaking changes
66

7+
#### Threadsafe evaluation
8+
9+
DynamicPPL models are by default no longer thread-safe.
10+
If you have threading in a model, you **must** now manually mark it as so, using:
11+
12+
```julia
13+
@model f() = ...
14+
model = f()
15+
model = setthreadsafe(model, true)
16+
```
17+
18+
It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`).
19+
20+
The problem with this approach is that it sacrifices a huge amount of performance.
21+
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
22+
23+
**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
24+
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
25+
26+
- tilde-statements
27+
- calls to `@addlogprob!`
28+
- any direct manipulation of the special `__varinfo__` variable
29+
30+
If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
31+
**Notably, the following do not require threadsafe evaluation:**
32+
33+
- Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
34+
- Sampling with `AbstractMCMC.MCMCThreads()`.
35+
736
#### Parent and leaf contexts
837

938
The `DynamicPPL.NodeTrait` function has been removed.

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref):
4242
contextualize
4343
```
4444

45+
Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary).
46+
If this is the case, one must enable threadsafe evaluation for a model:
47+
48+
```@docs
49+
setthreadsafe
50+
```
51+
4552
## Evaluation
4653

4754
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ export AbstractVarInfo,
9090
Model,
9191
getmissings,
9292
getargnames,
93+
setthreadsafe,
9394
extract_priors,
9495
values_as_in_model,
9596
# LogDensityFunction

src/compiler.jl

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
301301
modeldef = build_model_definition(expr)
302302

303303
# Generate main body
304-
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
304+
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false)
305305

306306
return build_output(modeldef, linenumbernode)
307307
end
@@ -346,36 +346,67 @@ Generate the body of the main evaluation function from expression `expr` and arg
346346
If `warn` is true, a warning is displayed if internal variables are used in the model
347347
definition.
348348
"""
349-
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
349+
generate_mainbody(mod, expr, warn, warned_about_threads_threads) =
350+
generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads)
350351

351-
generate_mainbody!(mod, found, x, warn) = x
352-
function generate_mainbody!(mod, found, sym::Symbol, warn)
352+
generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x
353+
function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads)
353354
if warn && sym in INTERNALNAMES && sym found
354355
@warn "you are using the internal variable `$sym`"
355356
push!(found, sym)
356357
end
357358

358359
return sym
359360
end
360-
function generate_mainbody!(mod, found, expr::Expr, warn)
361+
function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads)
361362
# Do not touch interpolated expressions
362363
expr.head === :$ && return expr.args[1]
363364

365+
# Flag to determine whether we've issued a warning for threadsafe macros Note that this
366+
# detection is not fully correct. We can only detect the presence of a macro that has
367+
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
368+
# Threads.@threads from Base.Threads.
369+
364370
# Do we don't want escaped expressions because we unfortunately
365371
# escape the entire body afterwards.
366-
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
372+
Meta.isexpr(expr, :escape) && return generate_mainbody(
373+
mod, found, expr.args[1], warn, warned_about_threads_threads
374+
)
367375

368376
# If it's a macro, we expand it
369377
if Meta.isexpr(expr, :macrocall)
370-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
378+
if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
379+
!warned_about_threads_threads
380+
warned_about_threads_threads = true
381+
@warn (
382+
"It looks like you are using `Threads.@threads` in your model definition." *
383+
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
384+
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
385+
"\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required."
386+
)
387+
else
388+
dump(expr.args[1])
389+
dump(:(Threads.@threads).args[1])
390+
end
391+
return generate_mainbody!(
392+
mod,
393+
found,
394+
macroexpand(mod, expr; recursive=true),
395+
warn,
396+
warned_about_threads_threads,
397+
)
371398
end
372399

373400
# Modify dotted tilde operators.
374401
args_dottilde = getargs_dottilde(expr)
375402
if args_dottilde !== nothing
376403
L, R = args_dottilde
377404
return generate_mainbody!(
378-
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
405+
mod,
406+
found,
407+
Base.remove_linenums!(generate_dot_tilde(L, R)),
408+
warn,
409+
warned_about_threads_threads,
379410
)
380411
end
381412

@@ -385,8 +416,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
385416
L, R = args_tilde
386417
return Base.remove_linenums!(
387418
generate_tilde(
388-
generate_mainbody!(mod, found, L, warn),
389-
generate_mainbody!(mod, found, R, warn),
419+
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
420+
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
390421
),
391422
)
392423
end
@@ -403,7 +434,13 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
403434
)
404435
end
405436

406-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
437+
return Expr(
438+
expr.head,
439+
map(
440+
x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads),
441+
expr.args,
442+
)...,
443+
)
407444
end
408445

409446
function generate_assign(left, right)

src/fasteval.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
212212
iden_varname_ranges::N
213213
varname_ranges::Dict{VarName,RangeAndLinked}
214214
end
215-
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
215+
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,false}})(
216+
params::AbstractVector{<:Real}
217+
) where {F,A,D,M,Ta,Td,Ctx}
216218
ctx = InitContext(
217219
Random.default_rng(),
218220
InitFromParams(
@@ -221,23 +223,27 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
221223
)
222224
model = DynamicPPL.setleafcontext(f.model, ctx)
223225
accs = fast_ldf_accs(f.getlogdensity)
224-
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
225-
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
226-
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
227-
# here.
228-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
229-
# it _should_ do, but this is wrong regardless.
230-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
231-
vi = if Threads.nthreads() > 1
232-
accs = map(
233-
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
234-
accs,
235-
)
236-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
237-
else
238-
OnlyAccsVarInfo(accs)
239-
end
240-
_, vi = DynamicPPL._evaluate!!(model, vi)
226+
_, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
227+
return f.getlogdensity(vi)
228+
end
229+
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,true}})(
230+
params::AbstractVector{<:Real}
231+
) where {F,A,D,M,Ta,Td,Ctx}
232+
ctx = InitContext(
233+
Random.default_rng(),
234+
InitFromParams(
235+
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
236+
),
237+
)
238+
model = DynamicPPL.setleafcontext(f.model, ctx)
239+
accs = fast_ldf_accs(f.getlogdensity)
240+
accs = map(
241+
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
242+
accs,
243+
)
244+
vi_wrapped = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
245+
_, vi_wrapped = DynamicPPL._evaluate!!(model, vi_wrapped)
246+
vi = OnlyAccsVarInfo(DynamicPPL.getaccs(vi_wrapped))
241247
return f.getlogdensity(vi)
242248
end
243249

src/model.jl

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
2+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded}
33
f::F
44
args::NamedTuple{argnames,Targs}
55
defaults::NamedTuple{defaultnames,Tdefaults}
@@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However,
1717
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
1818
are treated as random variables rather than observations.
1919
20+
The `Threaded` type parameter indicates whether the model requires threadsafe evaluation
21+
(i.e., whether the model contains statements which modify the internal VarInfo that are
22+
executed in parallel). By default, this is set to `false`.
23+
2024
The default arguments are used internally when constructing instances of the same model with
2125
different arguments.
2226
@@ -33,8 +37,9 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
3337
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
3438
```
3539
"""
36-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
37-
AbstractProbabilisticProgram
40+
struct Model{
41+
F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded
42+
} <: AbstractProbabilisticProgram
3843
f::F
3944
args::NamedTuple{argnames,Targs}
4045
defaults::NamedTuple{defaultnames,Tdefaults}
@@ -52,7 +57,7 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
5257
defaults::NamedTuple{defaultnames,Tdefaults},
5358
context::Ctx=DefaultContext(),
5459
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
55-
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
60+
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,false}(
5661
f, args, defaults, context
5762
)
5863
end
@@ -105,6 +110,31 @@ function setleafcontext(model::Model, context::AbstractContext)
105110
return contextualize(model, setleafcontext(model.context, context))
106111
end
107112

113+
"""
114+
setthreadsafe(model::Model, threadsafe::Bool)
115+
116+
Returns a new `Model` with its threadsafe flag set to `threadsafe`.
117+
118+
Threadsafe evaluation allows for parallel execution of model statements that mutate the
119+
internal `VarInfo` object. For example, this is needed if tilde-statements are nested inside
120+
`Threads.@threads` or similar constructs.
121+
122+
It is not needed for generic multithreaded operations that don't involve VarInfo. For
123+
example, calculating a log-likelihood term in parallel and then calling `@addlogprob!`
124+
outside of the parallel region is safe without needing to set `threadsafe=true`.
125+
126+
It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`.
127+
"""
128+
function setthreadsafe(
129+
model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, threadsafe::Bool
130+
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
131+
return if Threaded == threadsafe
132+
model
133+
else
134+
Model{M}(model.f, model.args, model.defaults, model.context)
135+
end
136+
end
137+
108138
"""
109139
model | (x = 1.0, ...)
110140
@@ -863,16 +893,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
863893
return first(init!!(rng, model, varinfo))
864894
end
865895

866-
"""
867-
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
868-
869-
Return `true` if evaluation of a model using `context` and `varinfo` should
870-
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
871-
"""
872-
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
873-
return Threads.nthreads() > 1
874-
end
875-
876896
"""
877897
init!!(
878898
[rng::Random.AbstractRNG,]
@@ -918,40 +938,14 @@ If multiple threads are available, the varinfo provided will be wrapped in a
918938
Returns a tuple of the model's return value, plus the updated `varinfo`
919939
(unwrapped if necessary).
920940
"""
921-
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
922-
return if use_threadsafe_eval(model.context, varinfo)
923-
evaluate_threadsafe!!(model, varinfo)
924-
else
925-
evaluate_threadunsafe!!(model, varinfo)
926-
end
927-
end
928-
929-
"""
930-
evaluate_threadunsafe!!(model, varinfo)
931-
932-
Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
933-
934-
If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
935-
This method is not exposed and supposed to be used only internally in DynamicPPL.
936-
937-
See also: [`evaluate_threadsafe!!`](@ref)
938-
"""
939-
function evaluate_threadunsafe!!(model, varinfo)
941+
function AbstractPPL.evaluate!!(
942+
model::Model{F,A,D,M,Ta,Td,Ctx,false}, varinfo::AbstractVarInfo
943+
) where {F,A,D,M,Ta,Td,Ctx}
940944
return _evaluate!!(model, resetaccs!!(varinfo))
941945
end
942-
943-
"""
944-
evaluate_threadsafe!!(model, varinfo, context)
945-
946-
Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`.
947-
948-
With the wrapper, Julia's multithreading can be used for observe statements in the `model`
949-
but parallel sampling will lead to undefined behaviour.
950-
This method is not exposed and supposed to be used only internally in DynamicPPL.
951-
952-
See also: [`evaluate_threadunsafe!!`](@ref)
953-
"""
954-
function evaluate_threadsafe!!(model, varinfo)
946+
function AbstractPPL.evaluate!!(
947+
model::Model{F,A,D,M,Ta,Td,Ctx,true}, varinfo::AbstractVarInfo
948+
) where {F,A,D,M,Ta,Td,Ctx}
955949
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
956950
result, wrapper_new = _evaluate!!(model, wrapper)
957951
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it

0 commit comments

Comments
 (0)