Skip to content

Commit 729bfba

Browse files
penelopeysmmhauru
andauthored
InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values (#984)
* Replace `evaluate_and_sample!!` -> `init!!` * Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends * Use `init!!` for initialisation * Paper over the `Sampling->Init` context stack (pending removal of SamplingContext) * Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway * Remove `predict` on vector of VarInfo * Fix some tests * Remove duplicated test * Simplify context testing * Rename FooInit -> InitFromFoo * Fix JETExt * Fix JETExt properly * Fix tests * Improve comments * Remove duplicated tests * Docstring improvements Co-authored-by: Markus Hauru <[email protected]> * Concretise `chain_sample_to_varname_dict` using chain value type * Clarify testset name * Re-add comment that shouldn't have vanished * Fix stale Requires dep * Fix default_varinfo/initialisation for odd models * Add comment to src/sampler.jl Co-authored-by: Markus Hauru <[email protected]> --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent 1e1cd94 commit 729bfba

23 files changed

+348
-720
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2323
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2424
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2525
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
26-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2726
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2827
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2928

@@ -71,7 +70,6 @@ Mooncake = "0.4.147"
7170
OrderedCollections = "1"
7271
Printf = "1.10"
7372
Random = "1.6"
74-
Requires = "1"
7573
Statistics = "1"
7674
Test = "1.6"
7775
julia = "1.10.8"

docs/src/api.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,6 @@ AbstractPPL.evaluate!!
447447

448448
This method mutates the `varinfo` used for execution.
449449
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
450-
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
451-
452-
```@docs
453-
DynamicPPL.evaluate_and_sample!!
454-
```
455450

456451
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
457452
Contexts are subtypes of `AbstractPPL.AbstractContext`.
@@ -466,7 +461,12 @@ InitContext
466461

467462
### VarInfo initialisation
468463

469-
`InitContext` is used to initialise, or overwrite, values in a VarInfo.
464+
The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
465+
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.
466+
467+
```@docs
468+
DynamicPPL.init!!
469+
```
470470

471471
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
472472
There are three concrete strategies provided in DynamicPPL:
@@ -505,7 +505,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
505505
```@docs
506506
DynamicPPL.initialstep
507507
DynamicPPL.loadstate
508-
DynamicPPL.initialsampler
508+
DynamicPPL.init_strategy
509509
```
510510

511511
Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

ext/DynamicPPLJETExt.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using JET: JET
66
function DynamicPPL.Experimental.is_suitable_varinfo(
77
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
88
)
9-
# Let's make sure that both evaluation and sampling doesn't result in type errors.
109
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1110
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1211
# This way we don't just fall back to untyped if the user's code is the issue.
@@ -21,32 +20,40 @@ end
2120
function DynamicPPL.Experimental._determine_varinfo_jet(
2221
model::DynamicPPL.Model; only_ddpl::Bool=true
2322
)
24-
# Use SamplingContext to test type stability.
25-
sampling_model = DynamicPPL.contextualize(
26-
model, DynamicPPL.SamplingContext(model.context)
27-
)
28-
29-
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(sampling_model)
23+
# Generate a typed varinfo to test model type stability with
24+
varinfo = DynamicPPL.typed_varinfo(model)
3125

32-
# Let's make sure that both evaluation and sampling doesn't result in type errors.
33-
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
sampling_model, varinfo; only_ddpl
26+
# Check type stability of evaluation (i.e. DefaultContext)
27+
model = DynamicPPL.contextualize(
28+
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext())
29+
)
30+
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
31+
model, varinfo; only_ddpl
3532
)
33+
if !eval_issuccess
34+
@debug "Evaluation with typed varinfo failed with the following issues:"
35+
@debug eval_result
36+
end
3637

37-
if !issuccess
38-
# Useful information for debugging.
39-
@debug "Evaluaton with typed varinfo failed with the following issues:"
40-
@debug result
38+
# Check type stability of initialisation (i.e. InitContext)
39+
model = DynamicPPL.contextualize(
40+
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
41+
)
42+
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
43+
model, varinfo; only_ddpl
44+
)
45+
if !init_issuccess
46+
@debug "Initialisation with typed varinfo failed with the following issues:"
47+
@debug init_result
4148
end
4249

43-
# If we didn't fail anywhere, we return the type stable one.
44-
return if issuccess
50+
# If neither of them failed, we can return the typed varinfo as it's type stable.
51+
return if (eval_issuccess && init_issuccess)
4552
varinfo
4653
else
4754
# Warn the user that we can't use the type stable one.
4855
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(sampling_model)
56+
DynamicPPL.untyped_varinfo(model)
5057
end
5158
end
5259

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
function _check_varname_indexing(c::MCMCChains.Chains)
2525
return DynamicPPL.supports_varname_indexing(c) ||
26-
error("Chains do not support indexing using `VarName`s.")
26+
error("This `Chains` object does not support indexing using `VarName`s.")
2727
end
2828

2929
function DynamicPPL.getindex_varname(
@@ -37,6 +37,17 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
3737
return keys(c.info.varname_to_symbol)
3838
end
3939

40+
function chain_sample_to_varname_dict(
41+
c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
42+
) where {Tval}
43+
_check_varname_indexing(c)
44+
d = Dict{DynamicPPL.VarName,Tval}()
45+
for vn in DynamicPPL.varnames(c)
46+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
47+
end
48+
return d
49+
end
50+
4051
"""
4152
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4253
@@ -109,9 +120,15 @@ function DynamicPPL.predict(
109120

110121
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
111122
predictive_samples = map(iters) do (sample_idx, chain_idx)
112-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
113-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
114-
123+
# Extract values from the chain
124+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
125+
# Resample any variables that are not present in `values_dict`
126+
_, varinfo = DynamicPPL.init!!(
127+
rng,
128+
model,
129+
varinfo,
130+
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
131+
)
115132
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
116133
varname_vals = mapreduce(
117134
collect,
@@ -243,13 +260,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
243260
varinfo = DynamicPPL.VarInfo(model)
244261
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
245262
return map(iters) do (sample_idx, chain_idx)
246-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
247-
# Update the varinfo with the current sample and make variables not present in `chain`
248-
# to be sampled.
249-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
250-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
251-
# `deepcopy` the `varinfo` before passing it to the `model`.
252-
model(deepcopy(varinfo))
263+
# Extract values from the chain
264+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
265+
# Resample any variables that are not present in `values_dict`, and
266+
# return the model's retval.
267+
retval, _ = DynamicPPL.init!!(
268+
model,
269+
varinfo,
270+
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
271+
)
272+
retval
253273
end
254274
end
255275

src/DynamicPPL.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,6 @@ include("test_utils.jl")
206206
include("experimental.jl")
207207
include("deprecated.jl")
208208

209-
if !isdefined(Base, :get_extension)
210-
using Requires
211-
end
212-
213209
# Better error message if users forget to load JET
214210
if isdefined(Base.Experimental, :register_error_hint)
215211
function __init__()

src/extract_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) =
123123
function extract_priors(rng::Random.AbstractRNG, model::Model)
124124
varinfo = VarInfo()
125125
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),))
126-
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
126+
varinfo = last(init!!(rng, model, varinfo))
127127
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
128128
end
129129

src/model.jl

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ end
850850
# ^ Weird Documenter.jl bug means that we have to write the two above separately
851851
# as it can only detect the `function`-less syntax.
852852
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
853-
return first(evaluate_and_sample!!(rng, model, varinfo))
853+
return first(init!!(rng, model, varinfo))
854854
end
855855

856856
"""
@@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
863863
return Threads.nthreads() > 1
864864
end
865865

866-
"""
867-
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
868-
869-
Evaluate the `model` with the given `varinfo`, but perform sampling during the
870-
evaluation using the given `sampler` by wrapping the model's context in a
871-
`SamplingContext`.
872-
873-
If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).
874-
875-
Returns a tuple of the model's return value, plus the updated `varinfo` object.
876-
"""
877-
function evaluate_and_sample!!(
878-
rng::Random.AbstractRNG,
879-
model::Model,
880-
varinfo::AbstractVarInfo,
881-
sampler::AbstractSampler=SampleFromPrior(),
882-
)
883-
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
884-
return evaluate!!(sampling_model, varinfo)
885-
end
886-
function evaluate_and_sample!!(
887-
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
888-
)
889-
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
890-
end
891-
892866
"""
893867
init!!(
894868
[rng::Random.AbstractRNG,]
@@ -897,12 +871,12 @@ end
897871
[init_strategy::AbstractInitStrategy=InitFromPrior()]
898872
)
899873
900-
Evaluate the `model` and replace the values of the model's random variables in
901-
the given `varinfo` with new values using a specified initialisation strategy.
902-
If the values in `varinfo` are not already present, they will be added using
903-
that same strategy.
874+
Evaluate the `model` and replace the values of the model's random variables
875+
in the given `varinfo` with new values, using a specified initialisation strategy.
876+
If the values in `varinfo` are not set, they will be added
877+
using a specified initialisation strategy.
904878
905-
If `init_strategy` is not provided, defaults to InitFromPrior().
879+
If `init_strategy` is not provided, defaults to `InitFromPrior()`.
906880
907881
Returns a tuple of the model's return value, plus the updated `varinfo` object.
908882
"""
@@ -1051,11 +1025,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10511025
Generate a sample of type `T` from the prior distribution of the `model`.
10521026
"""
10531027
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1054-
x = last(
1055-
evaluate_and_sample!!(
1056-
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
1057-
),
1058-
)
1028+
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
10591029
return values_as(x, T)
10601030
end
10611031

@@ -1227,25 +1197,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
12271197
end
12281198
end
12291199

1230-
"""
1231-
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
1232-
1233-
Generate samples from the posterior predictive distribution by evaluating `model` at each set
1234-
of parameter values provided in `chain`. The number of posterior predictive samples matches
1235-
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
1236-
and the predicted values.
1237-
"""
1238-
function predict(
1239-
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
1240-
)
1241-
varinfo = DynamicPPL.VarInfo(model)
1242-
return map(chain) do params_varinfo
1243-
vi = deepcopy(varinfo)
1244-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1245-
model(rng, vi)
1246-
return vi
1247-
end
1248-
end
1200+
# Implemented & documented in DynamicPPLMCMCChainsExt
1201+
function predict end
12491202

12501203
"""
12511204
returned(model::Model, parameters::NamedTuple)

0 commit comments

Comments
 (0)