Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it time to start a HISTORY.md entry? Might be easier to do it here were you can cross-check against what's being removed, rather than once everything is in breaking in a huge diff.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good idea to do it in this PR. I'll write one up later and ping you again

A core component of DynamicPPL is the [`@model`](@ref) macro.
It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements.
These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities.
These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities.

```@docs
@model
Expand Down Expand Up @@ -344,6 +344,13 @@ Base.empty!
SimpleVarInfo
```

### Tilde-pipeline

```@docs
tilde_assume!!
tilde_observe!!
```

### Accumulators

The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
Expand Down Expand Up @@ -447,12 +454,12 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
Expand Down Expand Up @@ -486,15 +493,7 @@ DynamicPPL.init

### Samplers

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.

```@docs
SampleFromPrior
SampleFromUniform
```

Additionally, a generic sampler for inference is implemented.
In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -520,9 +519,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
tilde_assume
```
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing
Expand Down
8 changes: 3 additions & 5 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,16 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
assume,
tilde_assume,
# Tilde pipeline
tilde_assume!!,
tilde_observe!!,
# Initialisation
InitContext,
AbstractInitStrategy,
Expand Down
217 changes: 87 additions & 130 deletions src/context_implementations.jl
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry this file's diff is a bit of a mess, there are no real code changes, it's just:

  • merged tilde_assume, tilde_assume!!, and assume into a single function (they all just called each other)
  • added types to the arguments
  • added proper docstrings

Original file line number Diff line number Diff line change
@@ -1,93 +1,102 @@
# assume
"""
tilde_assume(context::SamplingContext, right, vn, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value with a context associated
with a sampler.

Falls back to
```julia
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
```
DynamicPPL.tilde_assume!!(
context::AbstractContext,
right::Distribution,
vn::VarName,
vi::AbstractVarInfo
)

Handle assumed variables, i.e. anything which is not observed (see
[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the
sampled value and updated `vi`.

`vn` is the VarName on the left-hand side of the tilde statement.
"""
function tilde_assume(context::SamplingContext, right, vn, vi)
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
end

function tilde_assume(context::AbstractContext, args...)
return tilde_assume(childcontext(context), args...)
end
function tilde_assume(::DefaultContext, right, vn, vi)
return assume(right, vn, vi)
end

function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
return tilde_assume(rng, childcontext(context), args...)
end
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
return assume(rng, sampler, right, vn, vi)
function tilde_assume!!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm that I understand how this will play out in Turing.jl: The idea is that samplers don't need to modify the behaviour of the tilde pipeline any more, and thus SamplingContext can go in its entirety, and we don't need things like tilde_assume without !! or assume. And the few that do still need to do that (Gibbs, maybe particles, hopefully nothing else) need to define their own context. Is that right?

Copy link
Member Author

@penelopeysm penelopeysm Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, although I'm not 100% sure. We'll probably have to do a similar process to what we did last time for accumulators: make a Turing PR that builds against the breaking branch of DPPL. Things might break. But I'm hoping it won't be too bad haha.

context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
)
return tilde_assume!!(childcontext(context), right, vn, vi)
end
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
# same as above but no rng
return assume(Random.default_rng(), sampler, right, vn, vi)
function tilde_assume!!(
::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, right)
x, inv_logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
return x, vi
end

function tilde_assume(context::PrefixContext, right, vn, vi)
function tilde_assume!!(
context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
)
# Note that we can't use something like this here:
# new_vn = prefix(context, vn)
# return tilde_assume(childcontext(context), right, new_vn, vi)
# return tilde_assume!!(childcontext(context), right, new_vn, vi)
# This is because `prefix` applies _all_ prefixes in a given context to a
# variable name. Thus, if we had two levels of nested prefixes e.g.
# `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the
# first call would apply the prefix `a.b._`, and the recursive call
# would apply the prefix `b._`, resulting in `b.a.b._`.
# This is why we need a special function, `prefix_and_strip_contexts`.
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(new_context, right, new_vn, vi)
return tilde_assume!!(new_context, right, new_vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
"""
DynamicPPL.tilde_assume!!(
context::AbstractContext,
right::DynamicPPL.Submodel,
vn::VarName,
vi::AbstractVarInfo
)

Evaluate the submodel with the given context.
"""
function tilde_assume!!(
context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo
)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
return _evaluate!!(right, vi, context, vn)
end

"""
tilde_assume!!(context, right, vn, vi)
tilde_observe!!(
context::AbstractContext,
right::Distribution,
left,
vn::Union{VarName, Nothing},
vi::AbstractVarInfo
)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value and updated `vi`.
This function handles observed variables, which may be:

By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
return if right isa DynamicPPL.Submodel
_evaluate!!(right, vi, context, vn)
else
tilde_assume(context, right, vn, vi)
end
end
- literals on the left-hand side, e.g., `3.0 ~ Normal()`
- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end`
- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`.

# observe
"""
tilde_observe!!(context::SamplingContext, right, left, vi)
The relevant log-probability associated with the observation is computed and accumulated in
the VarInfo object `vi` (except for fixed variables, which do not contribute to the
log-probability).

Handle observed constants with a `context` associated with a sampler.
`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the
left-hand side, or `nothing` if the left-hand side is a literal value.

Falls back to `tilde_observe!!(context.context, right, left, vi)`.
Observations of submodels are not yet supported in DynamicPPL.
"""
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
return tilde_observe!!(context.context, right, left, vn, vi)
end

function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
function tilde_observe!!(
context::AbstractContext,
right::Distribution,
left,
vn::Union{VarName,Nothing},
vi::AbstractVarInfo,
)
return tilde_observe!!(childcontext(context), right, left, vn, vi)
end

# `PrefixContext`
function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
function tilde_observe!!(
context::PrefixContext,
right::Distribution,
left,
vn::Union{VarName,Nothing},
vi::AbstractVarInfo,
)
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the
# comment in `tilde_assume!!`.
Expand All @@ -98,74 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
end
return tilde_observe!!(new_context, right, left, new_vn, vi)
end

"""
tilde_observe!!(context, right, left, vn, vi)

Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the observed value and updated `vi`.

Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(::DefaultContext, right, left, vn, vi)
right isa DynamicPPL.Submodel &&
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
function tilde_observe!!(
::DefaultContext,
right::Distribution,
left,
vn::Union{VarName,Nothing},
vi::AbstractVarInfo,
)
vi = accumulate_observe!!(vi, right, left, vn)
return left, vi
end

function assume(::Random.AbstractRNG, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, inv_logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
return x, vi
end

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
function tilde_observe!!(
::AbstractContext,
::DynamicPPL.Submodel,
left,
vn::Union{VarName,Nothing},
::AbstractVarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
# if that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably be call a function called setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
else
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, vn, dist)
vi = push!!(vi, vn, f(r), dist)
# By default `push!!` sets the transformed flag to `false`.
vi = settrans!!(vi, true, vn)
else
vi = push!!(vi, vn, r, dist)
end
end

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, logjac, vn, dist)
return r, vi
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
end
Loading
Loading