29 changes: 29 additions & 0 deletions HISTORY.md
@@ -4,6 +4,35 @@

### Breaking changes

#### Threadsafe evaluation

DynamicPPL models are no longer thread-safe by default.
If you use threading inside a model, you **must** now mark it as such manually:

```julia
@model f() = ...
model = f()
model = setthreadsafe(model, true)
```

Previously, DynamicPPL would automatically enable thread-safe evaluation whenever Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`).

The problem with this approach is that it sacrifices a great deal of performance.
Furthermore, it is not actually correct: the mere fact that Julia has multiple threads does not mean that a particular model requires threadsafe evaluation.

**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:

- tilde-statements
- calls to `@addlogprob!`
- any direct manipulation of the special `__varinfo__` variable

If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
**Notably, the following do not require threadsafe evaluation:**

- Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
- Sampling with `AbstractMCMC.MCMCThreads()`.
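
To make the distinction concrete, here is a sketch (model names and data are illustrative, assuming the `setthreadsafe` API described above):

```julia
using DynamicPPL, Distributions

# Does NOT need threadsafe evaluation: the threaded block only fills a plain
# array; the single `@addlogprob!` call happens outside it.
@model function parallel_logprob(xs)
    mu ~ Normal()
    lps = Vector{Float64}(undef, length(xs))
    Threads.@threads for i in eachindex(xs)
        lps[i] = logpdf(Normal(mu, 1), xs[i])
    end
    @addlogprob! sum(lps)
end

# DOES need threadsafe evaluation: the tilde-statement itself runs inside
# `Threads.@threads`, so the VarInfo is mutated in parallel.
@model function parallel_tilde(xs)
    mu ~ Normal()
    Threads.@threads for i in eachindex(xs)
        xs[i] ~ Normal(mu, 1)
    end
end

model = parallel_tilde(randn(100))
model = setthreadsafe(model, true)  # must be marked explicitly
```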

#### Parent and leaf contexts

The `DynamicPPL.NodeTrait` function has been removed.
7 changes: 7 additions & 0 deletions docs/src/api.md
@@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref):
contextualize
```

Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary).
Review comment (Member): Just a reminder note to change this before merging.

If this is the case, one must enable threadsafe evaluation for a model:

```@docs
setthreadsafe
```

## Evaluation

With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -90,6 +90,7 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
setthreadsafe,
extract_priors,
values_as_in_model,
# LogDensityFunction
60 changes: 47 additions & 13 deletions src/compiler.jl
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
modeldef = build_model_definition(expr)

# Generate main body
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false)

return build_output(modeldef, linenumbernode)
end
@@ -346,36 +346,64 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
generate_mainbody(mod, expr, warn, warned_about_threads_threads) =
generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads)

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x
function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads)
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Flag to determine whether we've already issued a warning about threaded macros.
# Note that this detection is not fully correct: we can only detect the presence
# of a macro spelled `Threads.@threads`; we cannot check whether it *actually is*
# `Threads.@threads` from `Base.Threads`.

# We don't want to recurse into escaped expressions, because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
Meta.isexpr(expr, :escape) && return generate_mainbody!(
mod, found, expr.args[1], warn, warned_about_threads_threads
)

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
Review thread:
Member: I wonder how often people do `using Threads: @threads` followed by a bare `@threads`, vs how often people call some other macro called `@threads`; i.e., false positives vs false negatives.
Member (author): Another case where I wasn't too sure where to draw the line. I think `Threads.@threads` probably accounts for most usage, but I'm not averse to also handling `@threads`, since after all the warning message is quite noncommittal.
Member: I would err on the safe side in warning about `@threads`, but happy if you prefer otherwise.
Member (author): Agreed, changed now.

!warned_about_threads_threads
warned_about_threads_threads = true
@warn (
"It looks like you are using `Threads.@threads` in your model definition." *
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
"\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required."
)
end
return generate_mainbody!(
mod,
found,
macroexpand(mod, expr; recursive=true),
warn,
warned_about_threads_threads,
)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_mainbody!(
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
mod,
found,
Base.remove_linenums!(generate_dot_tilde(L, R)),
warn,
warned_about_threads_threads,
)
end

@@ -385,8 +413,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_tilde
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
),
)
end
@@ -397,13 +425,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_assign
return Base.remove_linenums!(
generate_assign(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(
x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads),
expr.args,
)...,
)
end
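
As a side note on the detection above: it compares against the parsed form of a qualified macro call. A small sketch of what that comparison matches (plain Julia, no DynamicPPL required):

```julia
ex = Meta.parse("Threads.@threads for i in 1:4 end")
# A :macrocall expression stores the macro reference as its first argument;
# a qualified call parses to a dotted expression.
@assert ex.head == :macrocall
@assert ex.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads")))

# A bare `@threads` (e.g. after `using Base.Threads`) parses to a plain
# Symbol instead, so it would not match the check above.
bare = Meta.parse("@threads for i in 1:4 end")
@assert bare.args[1] == Symbol("@threads")
```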

function generate_assign(left, right)
7 changes: 5 additions & 2 deletions src/debug_utils.jl
@@ -424,8 +424,11 @@ function check_model_and_trace(
# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)

# Force single-threaded execution.
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
# check on the merged accumulator, rather than checking it in the accumulate_assume
# calls. That way we can also support multi-threaded evaluation and use `evaluate!!`
# here instead of `_evaluate!!`.
_, varinfo = DynamicPPL._evaluate!!(model, varinfo)

# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
42 changes: 24 additions & 18 deletions src/fasteval.jl
@@ -212,7 +212,9 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
iden_varname_ranges::N
varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,false}})(
params::AbstractVector{<:Real}
) where {F,A,D,M,Ta,Td,Ctx}
ctx = InitContext(
Random.default_rng(),
InitFromParams(
@@ -221,23 +223,27 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
)
model = DynamicPPL.setleafcontext(f.model, ctx)
accs = fast_ldf_accs(f.getlogdensity)
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
# here.
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
accs = map(
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
accs,
)
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
else
OnlyAccsVarInfo(accs)
end
_, vi = DynamicPPL._evaluate!!(model, vi)
_, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
return f.getlogdensity(vi)
end
function (f::FastLogDensityAt{Model{F,A,D,M,Ta,Td,Ctx,true}})(
params::AbstractVector{<:Real}
) where {F,A,D,M,Ta,Td,Ctx}
ctx = InitContext(
Random.default_rng(),
InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
),
)
model = DynamicPPL.setleafcontext(f.model, ctx)
accs = fast_ldf_accs(f.getlogdensity)
accs = map(
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
accs,
)
vi_wrapped = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
_, vi_wrapped = DynamicPPL._evaluate!!(model, vi_wrapped)
vi = OnlyAccsVarInfo(DynamicPPL.getaccs(vi_wrapped))
Review thread:
Member: Would it be simpler to have a single method and wrap this part in an `if is_threaded(f.model)`? It should get constant-propagated away at compile time. Very optional to change; the current version isn't bad either. This is probably purely a style question, but there could be a difference in that listing all the type parameters of `Model` in the function signature, I think, forces specialisation on all of them.
Member (author, @penelopeysm, Nov 21, 2025): Yes, good idea; I was a bit concerned about the specialisation too. I thought about making `threaded` the first type parameter, which would also avoid this (and we rarely dispatch on any other type parameters in `Model`), then decided against it because it might be too breaking.

return f.getlogdensity(vi)
end
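
To illustrate the trade-off discussed in the review thread, here is a hedged sketch (types and names are hypothetical, not DynamicPPL's) of dispatching on a type parameter versus branching on it inside a single method:

```julia
struct MiniModel{F,Threaded}
    f::F
end
MiniModel(f, threaded::Bool) = MiniModel{typeof(f),threaded}(f)
is_threaded(::MiniModel{F,T}) where {F,T} = T

# Style 1: two methods dispatching on the flag. Spelling out the type
# parameters in the signature also forces specialisation on all of them.
run_model(m::MiniModel{F,false}) where {F} = m.f()
run_model(m::MiniModel{F,true}) where {F} = lock(() -> m.f(), ReentrantLock())

# Style 2: one method with a branch. `is_threaded` depends only on the type,
# so the branch is resolved (constant-propagated) at compile time.
function run_model2(m::MiniModel)
    if is_threaded(m)
        lock(() -> m.f(), ReentrantLock())
    else
        m.f()
    end
end
```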
