diff --git a/HISTORY.md b/HISTORY.md index 3f20cfd2f..f69c4a6fd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -56,6 +56,10 @@ Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo **Other changes** +### `setleafcontext(model, context)` + +This convenience method has been added to quickly modify the leaf context of a model. + ### Reimplementation of functions using `InitContext` A number of functions have been reimplemented and unified with the help of `InitContext`. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 55016d40c..e0163bb35 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -24,9 +24,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( varinfo = DynamicPPL.typed_varinfo(model) # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) @@ -36,9 +34,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( end # Check type stability of initialisation (i.e. InitContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index b3cf77121..7cc800dbb 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -830,8 +830,7 @@ end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) # Note that in practice this method is only called for SimpleVarInfo, because VarInfo # has a dedicated implementation - ctx = DynamicTransformationContext{false}() - model = contextualize(model, setleafcontext(model.context, ctx)) + model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return settrans!!(vi, t) end @@ -893,8 +892,7 @@ end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) # Note that in practice this method is only called for SimpleVarInfo, because VarInfo # has a dedicated implementation - ctx = DynamicTransformationContext{true}() - model = contextualize(model, setleafcontext(model.context, ctx)) + model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return settrans!!(vi, NoTransformation()) end diff --git a/src/contexts.jl b/src/contexts.jl index 70f99a73f..32a236e8e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -58,16 +58,17 @@ DynamicTransformationContext{true}() setchildcontext """ - leafcontext(context) + leafcontext(context::AbstractContext) Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. """ -leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context) = context -leafcontext(::IsParent, context) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = + leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context::AbstractContext) = context +leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) """ - setleafcontext(left, right) + setleafcontext(left::AbstractContext, right::AbstractContext) Return `left` but now with its leaf context replaced by `right`. @@ -103,19 +104,21 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left, right) +function setleafcontext(left::AbstractContext, right::AbstractContext) return setleafcontext( NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) end -function setleafcontext(::IsParent, ::IsParent, left, right) +function setleafcontext( + ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext +) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left, right) +function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -setleafcontext(::IsLeaf, ::IsParent, left, right) = right -setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right +setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( diff --git a/src/model.jl b/src/model.jl index a6a3e0685..6c7e8de94 100644 --- a/src/model.jl +++ b/src/model.jl @@ -95,6 +95,16 @@ function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end +""" + setleafcontext(model::Model, context::AbstractContext) + +Return a new `Model` with its leaf context set to `context`. This is a convenience shortcut +for `contextualize(model, setleafcontext(model.context, context)`). +""" +function setleafcontext(model::Model, context::AbstractContext) + return contextualize(model, setleafcontext(model.context, context)) +end + """ model | (x = 1.0, ...) @@ -886,8 +896,7 @@ function init!!( varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) - new_model = contextualize(model, new_context) + new_model = setleafcontext(model, InitContext(rng, init_strategy)) return evaluate!!(new_model, varinfo) end function init!!( diff --git a/src/sampler.jl b/src/sampler.jl index 8b49f6c3b..c598e13f5 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -46,12 +46,12 @@ function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSample end """ - init_strategy(sampler) + init_strategy(sampler::AbstractSampler) Define the initialisation strategy used for generating initial values when sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. """ -init_strategy(::Sampler) = InitFromPrior() +init_strategy(::AbstractSampler) = InitFromPrior() function AbstractMCMC.sample( rng::Random.AbstractRNG, @@ -60,11 +60,15 @@ function AbstractMCMC.sample( N::Integer; chain_type=default_chain_type(sampler), resume_from=nothing, + initial_params=init_strategy(sampler), initial_state=loadstate(resume_from), kwargs..., ) + if hasproperty(kwargs, :initial_parameters) + @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." + end return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_state, kwargs... + rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... ) end @@ -76,12 +80,25 @@ function AbstractMCMC.sample( N::Integer, nchains::Integer; chain_type=default_chain_type(sampler), + initial_params=fill(init_strategy(sampler), nchains), resume_from=nothing, initial_state=loadstate(resume_from), kwargs..., ) + if hasproperty(kwargs, :initial_parameters) + @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." + end return AbstractMCMC.mcmcsample( - rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs... + rng, + model, + sampler, + parallel, + N, + nchains; + chain_type, + initial_params, + initial_state, + kwargs..., ) end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index d53ba6c5f..aae2e4ec6 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -47,7 +47,7 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) typed_vi = DynamicPPL.typed_varinfo(untyped_vi) # Set the test context as the new leaf context - new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context)) + new_model = DynamicPPL.setleafcontext(model, context) # Check that evaluation works for vi in [untyped_vi, typed_vi] _, vi = DynamicPPL.evaluate!!(new_model, vi) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f89a562e3..e86a4c4ae 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -103,16 +103,12 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) + model = setleafcontext(model, DynamicTransformationContext{false}()) return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) + model = setleafcontext(model, DynamicTransformationContext{true}()) return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index b34424a1c..8ed29e0c7 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -81,20 +81,14 @@ typed_vi = DynamicPPL.typed_varinfo(model) @info "Evaluating with DefaultContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f, argtypes) @info "Initialising with InitContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo )