diff --git a/HISTORY.md b/HISTORY.md index d67afcbfe..f7aa1a03c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -52,6 +52,10 @@ The separation of these functions was primarily implemented to avoid performing **Other changes** +### `setleafcontext(model, context)` + +This convenience method has been added to quickly modify the leaf context of a model. + ### Reimplementation of functions using `InitContext` A number of functions have been reimplemented and unified with the help of `InitContext`. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 55016d40c..e0163bb35 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -24,9 +24,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( varinfo = DynamicPPL.typed_varinfo(model) # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) @@ -36,9 +34,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( end # Check type stability of initialisation (i.e. InitContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) diff --git a/src/contexts.jl b/src/contexts.jl index 439da47e5..d95df9d2c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -61,16 +61,17 @@ DynamicTransformationContext{true}() setchildcontext """ - leafcontext(context) + leafcontext(context::AbstractContext) Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. """ -leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(context::AbstractContext) = + leafcontext(NodeTrait(leafcontext, context), context) leafcontext(::IsLeaf, context) = context leafcontext(::IsParent, context) = leafcontext(childcontext(context)) """ - setleafcontext(left, right) + setleafcontext(left::AbstractContext, right::AbstractContext) Return `left` but now with its leaf context replaced by `right`. @@ -106,7 +107,7 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left, right) +function setleafcontext(left::AbstractContext, right::AbstractContext) return setleafcontext( NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) diff --git a/src/model.jl b/src/model.jl index a6a3e0685..9272f8c2c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,15 @@ with its underlying context set to `context`. function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end +""" + setleafcontext(model::Model, context::AbstractContext) + +Return a new `Model` with its leaf context set to `context`. This is a convenience +shortcut for `contextualize(model, setleafcontext(model.context, context)`). +""" +function setleafcontext(model::Model, context::AbstractContext) + return contextualize(model, setleafcontext(model.context, context)) +end """ model | (x = 1.0, ...) @@ -886,8 +895,7 @@ function init!!( varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) - new_model = contextualize(model, new_context) + new_model = setleafcontext(model, InitContext(rng, init_strategy)) return evaluate!!(new_model, varinfo) end function init!!( diff --git a/src/sampler.jl b/src/sampler.jl index 8b49f6c3b..5bd09993c 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -51,7 +51,7 @@ end Define the initialisation strategy used for generating initial values when sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. """ -init_strategy(::Sampler) = InitFromPrior() +init_strategy(::AbstractSampler) = InitFromPrior() function AbstractMCMC.sample( rng::Random.AbstractRNG, @@ -59,12 +59,13 @@ function AbstractMCMC.sample( sampler::Sampler, N::Integer; chain_type=default_chain_type(sampler), + initial_params=init_strategy(sampler), resume_from=nothing, initial_state=loadstate(resume_from), kwargs..., ) return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_state, kwargs... + rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... ) end @@ -75,13 +76,23 @@ function AbstractMCMC.sample( parallel::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, nchains::Integer; + initial_params=fill(init_strategy(sampler), nchains), chain_type=default_chain_type(sampler), resume_from=nothing, initial_state=loadstate(resume_from), kwargs..., ) return AbstractMCMC.mcmcsample( - rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs... + rng, + model, + sampler, + parallel, + N, + nchains; + chain_type, + initial_params, + initial_state, + kwargs..., ) end @@ -89,7 +100,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler; - initial_params::AbstractInitStrategy=init_strategy(spl), + initial_params::AbstractInitStrategy, kwargs..., ) # Generate the default varinfo. Note that any parameters inside this varinfo diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index d53ba6c5f..aae2e4ec6 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -47,7 +47,7 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) typed_vi = DynamicPPL.typed_varinfo(untyped_vi) # Set the test context as the new leaf context - new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context)) + new_model = DynamicPPL.setleafcontext(model, context) # Check that evaluation works for vi in [untyped_vi, typed_vi] _, vi = DynamicPPL.evaluate!!(new_model, vi) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6ca3b9852..a2c6899c3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -103,16 +103,12 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) + model = setleafcontext(model, DynamicTransformationContext{false}()) return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) + model = setleafcontext(model, DynamicTransformationContext{true}()) return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end diff --git a/src/transforming.jl b/src/transforming.jl index 589dca031..f9c55231d 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -59,7 +59,7 @@ function _transform!!( model::Model, ) # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: - model = contextualize(model, setleafcontext(model.context, ctx)) + model = setleafcontext(model, ctx) vi = settrans!!(last(evaluate!!(model, vi)), t) return vi end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index b34424a1c..8ed29e0c7 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -81,20 +81,14 @@ typed_vi = DynamicPPL.typed_varinfo(model) @info "Evaluating with DefaultContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f, argtypes) @info "Initialising with InitContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo )