Skip to content

Commit b432bfc

Browse files
committed
Make threadsafe evaluation opt-in
1 parent 3cd8d34 commit b432bfc

File tree

9 files changed

+247
-181
lines changed

9 files changed

+247
-181
lines changed

HISTORY.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,35 @@
44

55
### Breaking changes
66

7+
#### Threadsafe evaluation
8+
9+
DynamicPPL models are by default no longer thread-safe.
10+
If you have threading in a model, you **must** now manually mark it as threadsafe, using:
11+
12+
```julia
13+
@model f() = ...
14+
model = f()
15+
model = setthreadsafe(model, true)
16+
```
17+
18+
It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`).
19+
20+
The problem with this approach is that it sacrifices a huge amount of performance.
21+
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
22+
23+
**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
24+
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
25+
26+
- tilde-statements
27+
- calls to `@addlogprob!`
28+
- any direct manipulation of the special `__varinfo__` variable
29+
30+
If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
31+
**Notably, the following do not require threadsafe evaluation:**
32+
33+
- Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
34+
- Sampling with `AbstractMCMC.MCMCThreads()`.
35+
736
#### Parent and leaf contexts
837

938
The `DynamicPPL.NodeTrait` function has been removed.

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref):
4242
contextualize
4343
```
4444

45+
Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary).
46+
If this is the case, one must enable threadsafe evaluation for a model:
47+
48+
```@docs
49+
setthreadsafe
50+
```
51+
4552
## Evaluation
4653

4754
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ export AbstractVarInfo,
9090
Model,
9191
getmissings,
9292
getargnames,
93+
setthreadsafe,
9394
extract_priors,
9495
values_as_in_model,
9596
# LogDensityFunction

src/compiler.jl

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
301301
modeldef = build_model_definition(expr)
302302

303303
# Generate main body
304-
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
304+
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false)
305305

306306
return build_output(modeldef, linenumbernode)
307307
end
@@ -346,36 +346,67 @@ Generate the body of the main evaluation function from expression `expr` and arg
346346
If `warn` is true, a warning is displayed if internal variables are used in the model
347347
definition.
348348
"""
349-
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
349+
generate_mainbody(mod, expr, warn, warned_about_threads_threads) =
350+
generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads)
350351

351-
generate_mainbody!(mod, found, x, warn) = x
352-
function generate_mainbody!(mod, found, sym::Symbol, warn)
352+
generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x
353+
function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads)
353354
if warn && sym in INTERNALNAMES && sym ∉ found
354355
@warn "you are using the internal variable `$sym`"
355356
push!(found, sym)
356357
end
357358

358359
return sym
359360
end
360-
function generate_mainbody!(mod, found, expr::Expr, warn)
361+
function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads)
361362
# Do not touch interpolated expressions
362363
expr.head === :$ && return expr.args[1]
363364

365+
# Flag to determine whether we've issued a warning for threadsafe macros. Note that this
366+
# detection is not fully correct. We can only detect the presence of a macro that has
367+
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
368+
# Threads.@threads from Base.Threads.
369+
364370
# We don't want escaped expressions because we unfortunately
365371
# escape the entire body afterwards.
366-
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
372+
Meta.isexpr(expr, :escape) && return generate_mainbody(
373+
mod, found, expr.args[1], warn, warned_about_threads_threads
374+
)
367375

368376
# If it's a macro, we expand it
369377
if Meta.isexpr(expr, :macrocall)
370-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
378+
if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
379+
!warned_about_threads_threads
380+
warned_about_threads_threads = true
381+
@warn (
382+
"It looks like you are using `Threads.@threads` in your model definition." *
383+
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
384+
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
385+
"\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required."
386+
)
387+
else
388+
dump(expr.args[1])
389+
dump(:(Threads.@threads).args[1])
390+
end
391+
return generate_mainbody!(
392+
mod,
393+
found,
394+
macroexpand(mod, expr; recursive=true),
395+
warn,
396+
warned_about_threads_threads,
397+
)
371398
end
372399

373400
# Modify dotted tilde operators.
374401
args_dottilde = getargs_dottilde(expr)
375402
if args_dottilde !== nothing
376403
L, R = args_dottilde
377404
return generate_mainbody!(
378-
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
405+
mod,
406+
found,
407+
Base.remove_linenums!(generate_dot_tilde(L, R)),
408+
warn,
409+
warned_about_threads_threads,
379410
)
380411
end
381412

@@ -385,8 +416,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
385416
L, R = args_tilde
386417
return Base.remove_linenums!(
387418
generate_tilde(
388-
generate_mainbody!(mod, found, L, warn),
389-
generate_mainbody!(mod, found, R, warn),
419+
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
420+
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
390421
),
391422
)
392423
end
@@ -403,7 +434,13 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
403434
)
404435
end
405436

406-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
437+
return Expr(
438+
expr.head,
439+
map(
440+
x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads),
441+
expr.args,
442+
)...,
443+
)
407444
end
408445

409446
function generate_assign(left, right)

src/fasteval.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
212212
iden_varname_ranges::N
213213
varname_ranges::Dict{VarName,RangeAndLinked}
214214
end
215-
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
215+
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,false}})(
216+
params::AbstractVector{<:Real}
217+
) where {F,A,D,M,Ta,Td,Ctx}
216218
ctx = InitContext(
217219
Random.default_rng(),
218220
InitFromParams(
@@ -221,23 +223,27 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
221223
)
222224
model = DynamicPPL.setleafcontext(f.model, ctx)
223225
accs = fast_ldf_accs(f.getlogdensity)
224-
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
225-
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
226-
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
227-
# here.
228-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
229-
# it _should_ do, but this is wrong regardless.
230-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
231-
vi = if Threads.nthreads() > 1
232-
accs = map(
233-
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
234-
accs,
235-
)
236-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
237-
else
238-
OnlyAccsVarInfo(accs)
239-
end
240-
_, vi = DynamicPPL._evaluate!!(model, vi)
226+
_, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
227+
return f.getlogdensity(vi)
228+
end
229+
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,true}})(
230+
params::AbstractVector{<:Real}
231+
) where {F,A,D,M,Ta,Td,Ctx}
232+
ctx = InitContext(
233+
Random.default_rng(),
234+
InitFromParams(
235+
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
236+
),
237+
)
238+
model = DynamicPPL.setleafcontext(f.model, ctx)
239+
accs = fast_ldf_accs(f.getlogdensity)
240+
accs = map(
241+
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
242+
accs,
243+
)
244+
vi_wrapped = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
245+
_, vi_wrapped = DynamicPPL._evaluate!!(model, vi_wrapped)
246+
vi = OnlyAccsVarInfo(DynamicPPL.getaccs(vi_wrapped))
241247
return f.getlogdensity(vi)
242248
end
243249

0 commit comments

Comments
 (0)