InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984

Open · wants to merge 15 commits into base: breaking

Changes from 14 commits
14 changes: 7 additions & 7 deletions docs/src/api.md
@@ -449,11 +449,6 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.
@@ -468,7 +463,12 @@ InitContext

### VarInfo initialisation

`InitContext` is used to initialise, or overwrite, values in a VarInfo.
The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
It is really a thin wrapper around `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
```

To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:
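The strategy list itself is elided from this hunk. For illustration, a minimal usage sketch of `init!!` (not part of the diff; it assumes `InitFromPrior`, which is used throughout this PR, is one of the three strategies):

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()
vi = VarInfo(model)

# Overwrite every value in `vi` with a fresh draw from the prior;
# `init!!` returns the model's return value plus the updated varinfo.
retval, vi = DynamicPPL.init!!(Random.default_rng(), model, vi, DynamicPPL.InitFromPrior())
```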
@@ -507,7 +507,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported functions
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
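A hypothetical sketch of how a sampler package might hook into this (the exact overload signature is an assumption; per this hunk, `init_strategy` replaces the old `initialsampler`):

```julia
using DynamicPPL

struct MyAlg end  # hypothetical algorithm type, wrapped by `DynamicPPL.Sampler`

# Assumed overload: choose the strategy used to generate the first set of
# parameter values for this sampler. `InitFromUniform()` is assumed to
# default to sampling uniformly in linked space.
DynamicPPL.init_strategy(::DynamicPPL.Sampler{MyAlg}) = DynamicPPL.InitFromUniform()
```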
43 changes: 25 additions & 18 deletions ext/DynamicPPLJETExt.jl
@@ -6,7 +6,6 @@ using JET: JET
function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
@@ -21,32 +20,40 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
Comment on lines 20 to 22 — @penelopeysm (Member, Author), Aug 13, 2025:

I'm a bit confused by the comments in this function because as far as I can tell it only ever tested sampling, not both sampling and evaluation. (That was also true going further back e.g. in v0.36)

This PR thus also changes the implementation of this function to test both evaluation and sampling (i.e. initialisation) and if either fails, it will return the untyped varinfo.

Sorry I had to make this change in this PR. There were a few unholy tests where one would end up evaluating a model with a `SamplingContext{<:InitContext}`, which would error unless I introduced special code to handle it, and I didn't really want to do that. JETExt was one of those unholy scenarios.

# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
# Generate a typed varinfo to test model type stability with
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
# Check type stability of evaluation (i.e. DefaultContext)
model = DynamicPPL.contextualize(
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext())
)
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
model, varinfo; only_ddpl
)
if !eval_issuccess
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug eval_result
end

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
# Check type stability of initialisation (i.e. InitContext)
model = DynamicPPL.contextualize(
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
)
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
model, varinfo; only_ddpl
)
if !init_issuccess
@debug "Initialisation with typed varinfo failed with the following issues:"
@debug init_result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
# If neither of them failed, we can return the typed varinfo as it's type stable.
return if (eval_issuccess && init_issuccess)
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end
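A sketch of how this entry point can be exercised directly (the call and its return shape are taken from the diff above; the toy model is illustrative):

```julia
using DynamicPPL, Distributions
using JET  # loading JET activates the DynamicPPLJETExt extension

@model function demo()
    x ~ Normal()
end

model = demo()
vi = DynamicPPL.typed_varinfo(model)

ok, report = DynamicPPL.Experimental.is_suitable_varinfo(model, vi)
ok || @warn "Typed varinfo is not type-stable for this model" report
```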

40 changes: 29 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -28,7 +28,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end
Comment on lines +45 to +52 — @penelopeysm (Member, Author):

Note that, if the chain does not store varnames inside its `info` field, `chain_sample_to_varname_dict` will fail.

I don't think this is a huge problem right now because every chain obtained via Turing's `sample()` will contain varnames:

https://github.com/TuringLang/Turing.jl/blob/1aa95ac91a115569c742bab74f7b751ed1450309/src/mcmc/Inference.jl#L288-L290

So this is only a problem if one manually constructs a chain and tries to call `predict` on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails). There are a few places in DynamicPPL's test suite where this does actually happen. I fixed them all by manually adding the varname dictionary.

However, it's obviously ugly. The only good way around this is to rework MCMCChains.jl :( (See here for the implementation of the corresponding functionality in FlexiChains.)
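To make the data flow concrete, a sketch of the helper in isolation (it is internal to the extension, so this is illustrative only; the hand-built `varname_to_symbol` mapping mimics what Turing's `sample()` stores):

```julia
using DynamicPPL, MCMCChains

# Build a tiny chain by hand, attaching the varname_to_symbol mapping that
# chain_sample_to_varname_dict relies on.
chain = MCMCChains.setinfo(
    MCMCChains.Chains(randn(10, 1, 1), [:x]),
    (varname_to_symbol=Dict(@varname(x) => :x),),
)

# Dict(@varname(x) => value of x at sample 1, chain 1)
values_dict = chain_sample_to_varname_dict(chain, 1, 1)
```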


"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

@@ -114,9 +123,15 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
@@ -248,13 +263,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
retval
end
end
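The `InitFromParams` fallback used here can be seen in isolation (a sketch, not from the diff; `y` is missing from the dict and is therefore redrawn from the prior):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
    return (; x, y)
end

model = demo()
strategy = DynamicPPL.InitFromParams(Dict(@varname(x) => 0.5), DynamicPPL.InitFromPrior())
retval, _ = DynamicPPL.init!!(model, VarInfo(model), strategy)
retval.x  # == 0.5, taken from the dict; retval.y varies from run to run
```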

2 changes: 1 addition & 1 deletion src/extract_priors.jl
@@ -123,7 +123,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) =
function extract_priors(rng::Random.AbstractRNG, model::Model)
varinfo = VarInfo()
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),))
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
varinfo = last(init!!(rng, model, varinfo))
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
end
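For reference, usage of this function after the change (a sketch; the return shape, a mapping from `VarName` to prior distribution, is an assumption based on the accumulator above):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Exponential()
end

# Assumed to evaluate the model once (now via init!!) and collect the
# prior distribution encountered at each tilde statement.
priors = DynamicPPL.extract_priors(demo())
```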

63 changes: 8 additions & 55 deletions src/model.jl
@@ -850,7 +850,7 @@ end
# ^ Weird Documenter.jl bug means that we have to write the two above separately
# as it can only detect the `function`-less syntax.
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
return first(evaluate_and_sample!!(rng, model, varinfo))
return first(init!!(rng, model, varinfo))
end

"""
@@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
return Threads.nthreads() > 1
end

"""
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])

Evaluate the `model` with the given `varinfo`, but perform sampling during the
evaluation using the given `sampler` by wrapping the model's context in a
`SamplingContext`.

If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).

Returns a tuple of the model's return value, plus the updated `varinfo` object.
"""
function evaluate_and_sample!!(
rng::Random.AbstractRNG,
model::Model,
varinfo::AbstractVarInfo,
sampler::AbstractSampler=SampleFromPrior(),
)
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
return evaluate!!(sampling_model, varinfo)
end
function evaluate_and_sample!!(
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
)
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
end

"""
init!!(
[rng::Random.AbstractRNG,]
@@ -897,10 +871,10 @@ end
[init_strategy::AbstractInitStrategy=InitFromPrior()]
)

Evaluate the `model` and replace the values of the model's random variables in
the given `varinfo` with new values using a specified initialisation strategy.
If the values in `varinfo` are not already present, they will be added using
that same strategy.
Evaluate the `model` and replace the values of the model's random variables
in the given `varinfo` with new values, using a specified initialisation strategy.
If the values in `varinfo` are not set, they will be added using the specified initialisation strategy.

If `init_strategy` is not provided, it defaults to `InitFromPrior()`.

@@ -1051,11 +1025,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
Generate a sample of type `T` from the prior distribution of the `model`.
"""
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
x = last(
evaluate_and_sample!!(
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
),
)
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
return values_as(x, T)
end
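Usage of this method (grounded in the signature above; `values_as` converts the internal `SimpleVarInfo` into the requested type):

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
end

# Draw one joint sample from the prior, materialised as a NamedTuple.
nt = rand(Random.default_rng(), NamedTuple, demo())
nt.x
```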

@@ -1227,25 +1197,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
end
end

"""
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})

Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi)
return vi
end
end
# Implemented & documented in DynamicPPLMCMCChainsExt
function predict end
Comment on lines -1230 to +1201 — @penelopeysm (Member, Author):

This was discussed at one of the meetings and we decided we didn't care enough about the `predict` method on vectors of varinfos. It's currently bugged because `varinfo` is always unlinked, but `params_varinfo` might be linked, and if it is, it will give wrong results because it sets a linked value into an unlinked varinfo. See #983.


"""
returned(model::Model, parameters::NamedTuple)