diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index edf44439e..b1b3bc3d9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -174,7 +174,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("contexts.jl") +include("contexts/default.jl") include("contexts/init.jl") +include("contexts/transformation.jl") +include("contexts/prefix.jl") +include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("sampler.jl") include("varname.jl") @@ -187,10 +191,8 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") -include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ac841baab..b3cf77121 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -827,6 +827,27 @@ end function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + ctx = DynamicTransformationContext{false}() + model = contextualize(model, setleafcontext(model.context, ctx)) + vi = last(evaluate!!(model, vi)) + return settrans!!(vi, t) +end +function link!!( + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model +) + b = inverse(t.bijector) + x = vi[:] + y, logjac = with_logabsdet_jacobian(b, x) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return settrans!!(vi, t) +end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -846,6 +867,9 @@ end function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) +end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -866,23 +890,14 @@ end function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end - -# Vector-based ones. -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - - # Set parameters and add the logjac term. - vi = unflatten(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return settrans!!(vi, t) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + ctx = DynamicTransformationContext{true}() + model = contextualize(model, setleafcontext(model.context, ctx)) + vi = last(evaluate!!(model, vi)) + return settrans!!(vi, NoTransformation()) end - function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) @@ -919,6 +934,9 @@ end function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) +end """ maybe_invlink_before_eval!!([t::Transformation,] vi, model) diff --git a/src/context_implementations.jl b/src/context_implementations.jl deleted file mode 100644 index a8f2d57e6..000000000 --- a/src/context_implementations.jl +++ /dev/null @@ -1,128 +0,0 @@ -""" - DynamicPPL.tilde_assume!!( - context::AbstractContext, - right::Distribution, - vn::VarName, - vi::AbstractVarInfo - ) - -Handle assumed variables, i.e. anything which is not observed (see -[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the -sampled value and updated `vi`. - -`vn` is the VarName on the left-hand side of the tilde statement. -""" -function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - return tilde_assume!!(childcontext(context), right, vn, vi) -end -function tilde_assume!!( - ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, right) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) - return x, vi -end -function tilde_assume!!( - context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume!!(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume!!(new_context, right, new_vn, vi) -end -""" - DynamicPPL.tilde_assume!!( - context::AbstractContext, - right::DynamicPPL.Submodel, - vn::VarName, - vi::AbstractVarInfo - ) - -Evaluate the submodel with the given context. -""" -function tilde_assume!!( - context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo -) - return _evaluate!!(right, vi, context, vn) -end - -""" - tilde_observe!!( - context::AbstractContext, - right::Distribution, - left, - vn::Union{VarName, Nothing}, - vi::AbstractVarInfo - ) - -This function handles observed variables, which may be: - -- literals on the left-hand side, e.g., `3.0 ~ Normal()` -- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` -- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. - -The relevant log-probability associated with the observation is computed and accumulated in -the VarInfo object `vi` (except for fixed variables, which do not contribute to the -log-probability). - -`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the -left-hand side, or `nothing` if the left-hand side is a literal value. - -Observations of submodels are not yet supported in DynamicPPL. -""" -function tilde_observe!!( - context::AbstractContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - return tilde_observe!!(childcontext(context), right, left, vn, vi) -end -function tilde_observe!!( - context::PrefixContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. - new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end -function tilde_observe!!( - ::DefaultContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - vi = accumulate_observe!!(vi, right, left, vn) - return left, vi -end -function tilde_observe!!( - ::AbstractContext, - ::DynamicPPL.Submodel, - left, - vn::Union{VarName,Nothing}, - ::AbstractVarInfo, -) - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) -end diff --git a/src/contexts.jl b/src/contexts.jl index 439da47e5..70f99a73f 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,6 +1,3 @@ -# Fallback traits -# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? - """ NodeTrait(context) NodeTrait(f, context) @@ -120,559 +117,62 @@ end setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right -# Contexts -""" - struct DefaultContext <: AbstractContext end - -The `DefaultContext` is used by default to accumulate values like the log joint probability -when running the model. -""" -struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() - -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. - -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) - -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - """ - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. -""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end - -""" - - ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} - -Model context that contains values that are to be conditioned on. The values -can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or -an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, -@varname(b) => 2)`). The former is more performant, but the latter must be used -when there are varnames that cannot be represented as symbols, e.g. -`@varname(x[1])`. -""" -struct ConditionContext{ - Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext - values::Values - context::Ctx -end - -const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} -const DictConditionContext = ConditionContext{<:AbstractDict} - -# Use DefaultContext as the default base context -function ConditionContext(values::Union{NamedTuple,AbstractDict}) - return ConditionContext(values, DefaultContext()) -end -# Optimisation when there are no values to condition on -ConditionContext(::NamedTuple{()}, context::AbstractContext) = context -# Same as above, and avoids method ambiguity with below -ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context -# Collapse consecutive levels of `ConditionContext`. Note that this overrides -# values inside the child context, thus giving precedence to the outermost -# `ConditionContext`. -function ConditionContext(values::NamedTuple, context::NamedConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end -function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::ConditionContext) - return print(io, "ConditionContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) - -""" - hasconditioned(context::AbstractContext, vn::VarName) - -Return `true` if `vn` is found in `context`. -""" -hasconditioned(context::AbstractContext, vn::VarName) = false -hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) -function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end - -""" - getconditioned(context::AbstractContext, vn::VarName) - -Return value of `vn` in `context`. -""" -function getconditioned(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -function getconditioned(context::ConditionContext, vn::VarName) - return getvalue(context.values, vn) -end - -""" - hasconditioned_nested(context, vn) - -Return `true` if `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) - return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) -end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end - -""" - getconditioned_nested(context, vn) - -Return the value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end -function getconditioned_nested(::IsParent, context, vn) - return if hasconditioned(context, vn) - getconditioned(context, vn) - else - getconditioned_nested(childcontext(context), vn) - end -end - -""" - decondition(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer conditioned on. - -Note that this recursively traverses contexts, deconditioning all along the way. - -See also: [`condition`](@ref) -""" -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) - return setchildcontext(context, decondition_context(childcontext(context), args...)) -end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end -function decondition_context(context::ConditionContext) - return decondition_context(childcontext(context)) -end -function decondition_context(context::ConditionContext, sym, syms...) - new_values = deepcopy(context.values) - for s in (sym, syms...) - new_values = BangBang.delete!!(new_values, s) - end - return if length(new_values) == 0 - # No more values left, can unwrap - decondition_context(childcontext(context), syms...) - else - ConditionContext( - new_values, decondition_context(childcontext(context), sym, syms...) - ) - end -end -function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} - return ConditionContext( - BangBang.delete!!(context.values, sym), - decondition_context(childcontext(context), vn), - ) -end - -""" - conditioned(context::AbstractContext) - -Return `NamedTuple` of values that are conditioned on under context`. - -Note that this will recursively traverse the context stack and return -a merged version of the condition values. -""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) -function conditioned(context::ConditionContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `conditioned` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, conditioned(childcontext(context))) -end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext - values::Values - context::Ctx -end - -const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} -const DictFixedContext = FixedContext{<:AbstractDict} - -FixedContext(values) = FixedContext(values, DefaultContext()) - -# Try to avoid nested `FixedContext`. -function FixedContext(values::NamedTuple, context::NamedFixedContext) - # Note that this potentially overrides values from `context`, thus giving - # precedence to the outmost `FixedContext`. - return FixedContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::FixedContext) - return print(io, "FixedContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::FixedContext) = IsParent() -childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) - -""" - hasfixed(context::AbstractContext, vn::VarName) +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. -Return `true` if a fixed value for `vn` is found in `context`. -""" -hasfixed(context::AbstractContext, vn::VarName) = false -hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) -function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end +`vn` is the VarName on the left-hand side of the tilde statement. +This function should return a tuple `(x, vi)`, where `x` is the sampled value (which +must be in unlinked space!) and `vi` is the updated VarInfo. """ - getfixed(context::AbstractContext, vn::VarName) - -Return the fixed value of `vn` in `context`. -""" -function getfixed(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) - -""" - hasfixed_nested(context, vn) - -Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) - return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) -end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end - -""" - getfixed_nested(context, vn) - -Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end -function getfixed_nested(::IsParent, context, vn) - return if hasfixed(context, vn) - getfixed(context, vn) - else - getfixed_nested(childcontext(context), vn) - end -end - -""" - fix([context::AbstractContext,] values::NamedTuple) - fix([context::AbstractContext]; values...) - -Return `FixedContext` with `values` and `context` if `values` is non-empty, -otherwise return `context` which is [`DefaultContext`](@ref) by default. - -See also: [`unfix`](@ref) -""" -fix(; values...) = fix(NamedTuple(values)) -fix(values::NamedTuple) = fix(DefaultContext(), values) -function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix((value, values...)) -end -function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) - return fix(DefaultContext(), values) -end -fix(context::AbstractContext, values::NamedTuple{()}) = context -function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) - return FixedContext(values, context) -end -function fix(context::AbstractContext; values...) - return fix(context, NamedTuple(values)) -end -function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix(context, (value, values...)) -end -function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) - return fix(context, Dict(values)) +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + return tilde_assume!!(childcontext(context), right, vn, vi) end """ - unfix(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer fixed. - -Note that this recursively traverses contexts, unfixing all along the way. - -See also: [`fix`](@ref) -""" -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) - return setchildcontext(context, unfix(childcontext(context), args...)) -end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end -function unfix(context::FixedContext) - return unfix(childcontext(context)) -end -function unfix(context::FixedContext, sym) - return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, sym, syms...) - return unfix( - fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), - syms..., + DynamicPPL.tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo ) -end - -function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, vn::VarName) - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) -end - -""" - fixed(context::AbstractContext) - -Return the values that are fixed under `context`. - -Note that this will recursively traverse the context stack and return -a merged version of the fix values. -""" -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) -function fixed(context::FixedContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `fixed` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, fixed(childcontext(context))) -end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end - -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. - -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); +This function handles observed variables, which may be: -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end - -""" - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) +Observations of submodels are not yet supported in DynamicPPL. -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` +This function should return a tuple `(left, vi)`, where `left` is the same as the input, and +`vi` is the updated VarInfo. """ -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, ) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl new file mode 100644 index 000000000..d3802de85 --- /dev/null +++ b/src/contexts/conditionfix.jl @@ -0,0 +1,467 @@ +""" + + ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} + +Model context that contains values that are to be conditioned on. The values +can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or +an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, +@varname(b) => 2)`). The former is more performant, but the latter must be used +when there are varnames that cannot be represented as symbols, e.g. +`@varname(x[1])`. +""" +struct ConditionContext{ + Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext +} <: AbstractContext + values::Values + context::Ctx +end + +const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} +const DictConditionContext = ConditionContext{<:AbstractDict} + +# Use DefaultContext as the default base context +function ConditionContext(values::Union{NamedTuple,AbstractDict}) + return ConditionContext(values, DefaultContext()) +end +# Optimisation when there are no values to condition on +ConditionContext(::NamedTuple{()}, context::AbstractContext) = context +# Same as above, and avoids method ambiguity with below +ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context +# Collapse consecutive levels of `ConditionContext`. Note that this overrides +# values inside the child context, thus giving precedence to the outermost +# `ConditionContext`. +function ConditionContext(values::NamedTuple, context::NamedConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end +function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) + +""" + hasconditioned(context::AbstractContext, vn::VarName) + +Return `true` if `vn` is found in `context`. +""" +hasconditioned(context::AbstractContext, vn::VarName) = false +hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) +function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getconditioned(context::AbstractContext, vn::VarName) + +Return value of `vn` in `context`. +""" +function getconditioned(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end + +""" + hasconditioned_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. +""" +function hasconditioned_nested(context::AbstractContext, vn) + return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) +end +hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) +function hasconditioned_nested(::IsParent, context, vn) + return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) +end +function hasconditioned_nested(context::PrefixContext, vn) + return hasconditioned_nested(collapse_prefix_stack(context), vn) +end + +""" + getconditioned_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getconditioned_nested(context::AbstractContext, vn) + return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) +end +function getconditioned_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getconditioned_nested(context::PrefixContext, vn) + return getconditioned_nested(collapse_prefix_stack(context), vn) +end +function getconditioned_nested(::IsParent, context, vn) + return if hasconditioned(context, vn) + getconditioned(context, vn) + else + getconditioned_nested(childcontext(context), vn) + end +end + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +See also: [`condition`](@ref) +""" +decondition_context(::IsLeaf, context, args...) = context +function decondition_context(::IsParent, context, args...) + return setchildcontext(context, decondition_context(childcontext(context), args...)) +end +function decondition_context(context, args...) + return decondition_context(NodeTrait(context), context, args...) +end +function decondition_context(context::ConditionContext) + return decondition_context(childcontext(context)) +end +function decondition_context(context::ConditionContext, sym, syms...) + new_values = deepcopy(context.values) + for s in (sym, syms...) + new_values = BangBang.delete!!(new_values, s) + end + return if length(new_values) == 0 + # No more values left, can unwrap + decondition_context(childcontext(context), syms...) + else + ConditionContext( + new_values, decondition_context(childcontext(context), sym, syms...) + ) + end +end +function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} + return ConditionContext( + BangBang.delete!!(context.values, sym), + decondition_context(childcontext(context), vn), + ) +end + +""" + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under context`. + +Note that this will recursively traverse the context stack and return +a merged version of the condition values. +""" +function conditioned(context::AbstractContext) + return conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = NamedTuple() +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return _merge(context.values, conditioned(childcontext(context))) +end +function conditioned(context::PrefixContext) + return conditioned(collapse_prefix_stack(context)) +end + +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} +const DictFixedContext = FixedContext{<:AbstractDict} + +FixedContext(values) = FixedContext(values, DefaultContext()) + +# Try to avoid nested `FixedContext`. +function FixedContext(values::NamedTuple, context::NamedFixedContext) + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `FixedContext`. + return FixedContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::FixedContext) + return print(io, "FixedContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::FixedContext) = IsParent() +childcontext(context::FixedContext) = context.context +setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) + +""" + hasfixed(context::AbstractContext, vn::VarName) + +Return `true` if a fixed value for `vn` is found in `context`. +""" +hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) +function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getfixed(context::AbstractContext, vn::VarName) + +Return the fixed value of `vn` in `context`. +""" +function getfixed(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) + +""" + hasfixed_nested(context, vn) + +Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. +""" +function hasfixed_nested(context::AbstractContext, vn) + return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) +end +hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) +function hasfixed_nested(::IsParent, context, vn) + return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) +end +function hasfixed_nested(context::PrefixContext, vn) + return hasfixed_nested(collapse_prefix_stack(context), vn) +end + +""" + getfixed_nested(context, vn) + +Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getfixed_nested(context::AbstractContext, vn) + return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) +end +function getfixed_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getfixed_nested(context::PrefixContext, vn) + return getfixed_nested(collapse_prefix_stack(context), vn) +end +function getfixed_nested(::IsParent, context, vn) + return if hasfixed(context, vn) + getfixed(context, vn) + else + getfixed_nested(childcontext(context), vn) + end +end + +""" + fix([context::AbstractContext,] values::NamedTuple) + fix([context::AbstractContext]; values...) + +Return `FixedContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`unfix`](@ref) +""" +fix(; values...) = fix(NamedTuple(values)) +fix(values::NamedTuple) = fix(DefaultContext(), values) +function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix((value, values...)) +end +function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) + return fix(DefaultContext(), values) +end +fix(context::AbstractContext, values::NamedTuple{()}) = context +function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) + return FixedContext(values, context) +end +function fix(context::AbstractContext; values...) + return fix(context, NamedTuple(values)) +end +function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix(context, (value, values...)) +end +function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) + return fix(context, Dict(values)) +end + +""" + unfix(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer fixed. + +Note that this recursively traverses contexts, unfixing all along the way. + +See also: [`fix`](@ref) +""" +unfix(::IsLeaf, context, args...) = context +function unfix(::IsParent, context, args...) + return setchildcontext(context, unfix(childcontext(context), args...)) +end +function unfix(context, args...) + return unfix(NodeTrait(context), context, args...) +end +function unfix(context::FixedContext) + return unfix(childcontext(context)) +end +function unfix(context::FixedContext, sym) + return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, sym, syms...) + return unfix( + fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), + syms..., + ) +end + +function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, vn::VarName) + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) +end + +""" + fixed(context::AbstractContext) + +Return the values that are fixed under `context`. + +Note that this will recursively traverse the context stack and return +a merged version of the fix values. +""" +fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) +fixed(::IsLeaf, context) = NamedTuple() +fixed(::IsParent, context) = fixed(childcontext(context)) +function fixed(context::FixedContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `fixed` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return _merge(context.values, fixed(childcontext(context))) +end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +########################################################################### +### Interaction of PrefixContext with ConditionContext and FixedContext ### +########################################################################### + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/contexts/default.jl b/src/contexts/default.jl new file mode 100644 index 000000000..ec21e1a56 --- /dev/null +++ b/src/contexts/default.jl @@ -0,0 +1,60 @@ +""" + struct DefaultContext <: AbstractContext end + +`DefaultContext`, as the name suggests, is the default context used when instantiating a +model. + +```jldoctest +julia> @model f() = x ~ Normal(); + +julia> model = f(); model.context +DefaultContext() +``` + +As an evaluation context, the behaviour of `DefaultContext` is to require all variables to be +present in the `AbstractVarInfo` used for evaluation. Thus, semantically, evaluating a model +with `DefaultContext` means 'calculating the log-probability associated with the variables +in the `AbstractVarInfo`'. +""" +struct DefaultContext <: AbstractContext end +NodeTrait(::DefaultContext) = IsLeaf() + +""" + DynamicPPL.tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + ) + +Handle assumed variables. For `DefaultContext`, this function extracts the value associated +with `vn` from `vi`, If `vi` does not contain an appropriate value then this will error. +""" +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +""" + DynamicPPL.tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, + ) + +Handle observed variables. This just accumulates the log-likelihood for `left`. +""" +function tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl new file mode 100644 index 000000000..24615e683 --- /dev/null +++ b/src/contexts/prefix.jl @@ -0,0 +1,116 @@ +""" + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} + +Create a context that allows you to use the wrapped `context` when running the model and +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`to_submodel`](@ref) +""" +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn + context::C +end +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) +end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) + +NodeTrait(::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) +end + +""" + prefix(ctx::AbstractContext, vn::VarName) + +Apply the prefixes in the context `ctx` to the variable name `vn`. +""" +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) +end +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) +end +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) +end + +""" + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. + +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. + +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. +""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume!!(new_context, right, new_vn, vi) +end + +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) +end diff --git a/src/transforming.jl b/src/contexts/transformation.jl similarity index 61% rename from src/transforming.jl rename to src/contexts/transformation.jl index 589dca031..720fa978f 100644 --- a/src/transforming.jl +++ b/src/contexts/transformation.jl @@ -43,31 +43,3 @@ function tilde_observe!!( ) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end - -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(t, DynamicTransformationContext{false}(), vi, model) -end - -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) -end - -function _transform!!( - t::AbstractTransformation, - ctx::DynamicTransformationContext, - vi::AbstractVarInfo, - model::Model, -) - # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: - model = contextualize(model, setleafcontext(model.context, ctx)) - vi = settrans!!(last(evaluate!!(model, vi)), t) - return vi -end - -function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 2ec8b15a2..13124e3a7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -1,7 +1,6 @@ module DebugUtils using ..DynamicPPL -using ..DynamicPPL: broadcast_safe, AbstractContext, childcontext using Random: Random using Accessors: Accessors diff --git a/src/submodel.jl b/src/submodel.jl index dcb107bb4..145bd42c9 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -8,6 +8,10 @@ struct Submodel{M,AutoPrefix} model::M end +# ---------------------- +# Constructing submodels +# ---------------------- + """ to_submodel(model::Model[, auto_prefix::Bool]) @@ -152,6 +156,26 @@ ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observ """ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) +# --------------------------- +# Submodels in tilde-pipeline +# --------------------------- + +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) + +Evaluate the submodel with the given context. +""" +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo +) + return _evaluate!!(right, vi, context, vn) +end + # When automatic prefixing is used, the submodel itself doesn't carry the # prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel # is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then @@ -193,3 +217,13 @@ function _evaluate!!( # returns a tuple of submodel.model's return value and the new varinfo. return _evaluate!!(model, vi) end + +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +end diff --git a/src/utils.jl b/src/utils.jl index a4c5f4a1b..b09bfb9fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -793,10 +793,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -broadcast_safe(x) = x -broadcast_safe(x::Distribution) = (x,) -broadcast_safe(x::AbstractContext) = (x,) - # Convert (x=1,) to Dict(@varname(x) => 1) function to_varname_dict(nt::NamedTuple) return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt))