diff --git a/docs/src/api.md b/docs/src/api.md index 14b2447b5..3dd157281 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. -These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. ```@docs @model @@ -344,6 +344,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -450,33 +457,45 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext -PrefixContext ConditionContext +InitContext ``` -### Samplers +### VarInfo initialisation + +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. + +```@docs +DynamicPPL.init!! +``` -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: ```@docs -SampleFromPrior -SampleFromUniform +PriorInit +UniformInit +ParamsInit ``` -Additionally, a generic sampler for inference is implemented. +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init +``` + +### Samplers + +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler @@ -487,7 +506,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. 
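To tie the documentation changes above together, here is a minimal sketch of how the new initialisation API is intended to be used, and of what a custom strategy might look like. The `demo` model, the `MeanInit` strategy, and the specific values are illustrative assumptions, not part of DynamicPPL itself.

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
    y ~ Beta(2, 2)
end

model = demo()
vi = VarInfo(model)

# Overwrite all values in `vi` by sampling from the prior.
_, vi = DynamicPPL.init!!(Random.default_rng(), model, vi, PriorInit())

# Stan-style initialisation: uniform on [-2, 2] in unconstrained space.
_, vi = DynamicPPL.init!!(model, vi, UniformInit())

# Take `x` from the given dictionary; fall back to the prior for `y`.
_, vi = DynamicPPL.init!!(model, vi, ParamsInit(Dict(@varname(x) => 0.5), PriorInit()))

# A hypothetical custom strategy: initialise every variable at its prior mean.
struct MeanInit <: DynamicPPL.AbstractInitStrategy end
function DynamicPPL.init(
    ::Random.AbstractRNG, ::DynamicPPL.VarName, dist::Distribution, ::MeanInit
)
    return mean(dist)
end

_, vi = DynamicPPL.init!!(model, vi, MeanInit())
```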
@@ -502,9 +521,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index ceb3f4981..f2d24ad92 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) = diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..89a36ffaf 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -21,22 +21,17 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. + # Let's make sure that evaluation doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + model, varinfo; only_ddpl ) if !issuccess # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" + @debug "Evaluation with typed varinfo failed with the following issues:" @debug result end @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." 
- DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..cd86cfb5e 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -28,7 +28,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Any}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -114,9 +123,15 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. 
+ retval, _ = DynamicPPL.init!!( + model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()) + ) + retval end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f190c7605..4c2702f17 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -97,18 +97,21 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, - PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, + # Initialisation + InitContext, + AbstractInitStrategy, + PriorInit, + UniformInit, + ParamsInit, # Pseudo distributions NamedDist, NoDist, @@ -170,11 +173,13 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") +include("prefix.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 6384eaa7c..4266ac9db 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -60,11 +60,14 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) +function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr)) + @gensym vn return quote - if $(DynamicPPL.contextual_isassumption)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like + # the whole `isassumption` thing to be simplified, though, so I'll + # leave it till later. + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn) # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($left_vn, __model__)) || + $(DynamicPPL.inmissings)($left_vn, __model__) true else $(maybe_view(expr)) === missing @@ -99,7 +102,7 @@ isassumption(expr) = :(false) Return `true` if `vn` is considered an assumption by `context`. """ -function contextual_isassumption(context::AbstractContext, vn) +function contextual_isassumption(context::AbstractContext, vn::VarName) if hasconditioned_nested(context, vn) val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? 
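The compiler changes above and below hinge on how the left-hand-side `VarName` is combined with the model's stored prefix. Below is a small sketch of the intended `DynamicPPL.maybe_prefix` behaviour, based on its definition later in this diff (`src/prefix.jl`); the `inner_demo` model is a made-up example and the expected results are assumptions rather than doctested output.

```julia
using DynamicPPL, Distributions

@model inner_demo() = x ~ Normal()

# A VarName combined with a model prefix: `x` under the prefix `a` becomes `a.x`.
DynamicPPL.maybe_prefix(@varname(x), @varname(a))  # == @varname(a.x)

# An unprefixed model stores `prefix == nothing`, so the name is left untouched.
DynamicPPL.maybe_prefix(@varname(x), nothing)      # == @varname(x)

# `DynamicPPL.prefix` composes with any existing prefix, newest prefix outermost.
m = DynamicPPL.prefix(DynamicPPL.prefix(inner_demo(), @varname(a)), @varname(b))
keys(rand(m))  # expected to be (Symbol("b.a.x"),)
```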
@@ -115,9 +118,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) - return :($(DynamicPPL.contextual_isfixed)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - )) + return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn)) end """ @@ -413,7 +414,9 @@ function generate_assign(left, right) return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) - $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) + $vn = $(DynamicPPL.maybe_prefix)( + $(make_varname_expression(left)), __model__.prefix + ) __varinfo__ = $(map_accumulator!!)( $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) @@ -448,24 +451,23 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption value dist + @gensym left_vn vn isassumption value dist return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) - $isassumption = $(DynamicPPL.isassumption(left, vn)) + $left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + $isassumption = $(DynamicPPL.isassumption(left, left_vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + $left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # If `left_vn` is not in `argnames`, we need to make sure that the variable is defined. + # (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had + # prefixes applied!) + if !$(DynamicPPL.inargnames)($left_vn, __model__) + $left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __model__.context, + __model__.prefix, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..1afad3963 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,104 +1,30 @@ # assume -""" - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` -""" -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) +function tilde_assume!!(context::AbstractContext, prefix, right::Distribution, vn, vi) + return tilde_assume!!(childcontext(context), prefix, right, vn, vi) end - -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) 
-end -function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) -end - -function tilde_assume(context::PrefixContext, right, vn, vi) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi -) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) +function tilde_assume!!(::DefaultContext, prefix, right::Distribution, vn, vi) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi end """ - tilde_assume!!(context, right, vn, vi) + tilde_assume!!(context, prefix, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. """ -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end +function tilde_assume!!(context, prefix, right::DynamicPPL.Submodel, vn, vi) + return _evaluate!!(right, vi, context, prefix, vn) end # observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe!!(context.context, right, left, vi)`. -""" -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - function tilde_observe!!(context::AbstractContext, right, left, vn, vi) return tilde_observe!!(childcontext(context), right, left, vn, vi) end -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. 
- new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end - """ tilde_observe!!(context, right, left, vn, vi) @@ -108,64 +34,11 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, -) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi +function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) end diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..0679ed7e3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. 
```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. -""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end @@ -197,124 +130,6 @@ when running the model. struct DefaultContext <: AbstractContext end NodeTrait(::DefaultContext) = IsLeaf() -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. 
- -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) - -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - -""" - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. -""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn - ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end - -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. 
- -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} @@ -400,9 +215,6 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end """ getconditioned_nested(context, vn) @@ -418,9 +230,6 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) getconditioned(context, vn) @@ -489,9 +298,6 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -554,9 +360,6 @@ hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end """ getfixed_nested(context, vn) @@ -572,9 +375,6 @@ end function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) getfixed(context, vn) @@ -668,113 +468,3 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end - -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. - -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. 
- -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) - -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); - -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end - -""" - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) - -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. 
- -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext - -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) - -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` -""" -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) -end diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..23b7fa7ab --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,176 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Values must be unlinked" + The values returned by `init` are always in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::UniformInit)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + PriorInit() + +Obtain new values by sampling from the prior distribution. +""" +struct PriorInit <: AbstractInitStrategy end +init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) + +""" + UniformInit() + UniformInit(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. + +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. + +Requires that `lower <= upper`. 
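As an editor's illustration of the warning above (a sketch, not a doctest; the variable names are arbitrary): the strategy is queried through the four-argument `init(rng, vn, dist, strategy)` defined below, and the value it returns lies in the support of `dist`, not in `[lower, upper]`.

```julia
using DynamicPPL, Distributions, Random

rng = Random.default_rng()

# Beta is constrained to (0, 1): the uniform draw happens on the
# unconstrained (logit) scale and is mapped back through the bijector,
# so the result is always inside the support.
x = DynamicPPL.init(rng, @varname(p), Beta(2, 2), UniformInit())
@assert 0 < x < 1

# The same holds for a custom range; LogNormal values stay positive.
y = DynamicPPL.init(rng, @varname(z), LogNormal(), UniformInit(-5.0, 5.0))
@assert y > 0
```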
+ +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + UniformInit() = UniformInit(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) + ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + +Obtain new values by extracting them from the given dictionary or NamedTuple. +The parameter `default` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. The default +for `default` is `PriorInit()`. + +!!! note + These values must be provided in the space of the untransformed distribution. +""" +struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy + params::P + default::S + function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + return new{typeof(params),typeof(default)}(params, default) + end + ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) + function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + return ParamsInit(to_varname_dict(params), default) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + init(rng, vn, dist, p.default) + else + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? + x + end + else + init(rng, vn, dist, p.default) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=PriorInit()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. 
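A compact sketch of what the sentence above amounts to in practice; this is roughly what `DynamicPPL.init!!` (introduced in `src/model.jl` further down this diff) does, and the `demo_init` model is a made-up example.

```julia
using DynamicPPL, Distributions, Random

@model demo_init() = x ~ Normal()

model = demo_init()
vi = VarInfo(model)

# Put an InitContext at the leaf of the model's context stack...
ctx = InitContext(Random.default_rng(), UniformInit())
init_model = DynamicPPL.contextualize(
    model, DynamicPPL.setleafcontext(model.context, ctx)
)

# ...then a plain evaluation overwrites every value in `vi`.
_, vi = DynamicPPL.evaluate!!(init_model, vi)
```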
+""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=PriorInit()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume!!( + ctx::InitContext, prefix, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo already had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + link_transform(dist) + else + identity + end + y, logjac = with_logabsdet_jacobian(f, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index d71fa57cc..ef9e1b8cf 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -475,7 +475,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 64dcf2eea..5342d70c4 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -117,7 +117,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...)
= function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index ac9968cf2..7e4134993 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,9 +1,19 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + Tprefix<:Union{Nothing,<:VarName} + } f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx=DefaultContext() + prefix::Tprefix=nothing end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` @@ -33,12 +43,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + Tprefix<:Union{Nothing,<:VarName}, +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + prefix::Tprefix @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +70,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + prefix::Tprefix=nothing, + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Tprefix} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Tprefix}( + f, args, defaults, context, prefix ) end end @@ -71,18 +91,27 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), -) where {F,argnames,Targs,kwargnames,Tkwargs} + prefix::Tprefix=nothing, +) where {F,argnames,Targs,kwargnames,Tkwargs,Tprefix} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...)}( + f, args, defaults, context, prefix + )) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + prefix::Union{Nothing,<:VarName}=nothing; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, prefix) end """ @@ -92,7 +121,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. 
""" function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, model.prefix) end """ @@ -417,7 +446,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned julia> @model function demo() m ~ Normal() @@ -427,36 +456,25 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have conditioned on + their values. - conditioned(condition(m, x=100.0, m=1.0)) -(x = 100.0, m = 1.0) - -julia> # Nested ones also work. - # (Note that `PrefixContext` also prefixes the variables of any - # ConditionContext that is _inside_ it; because of this, the type of the - # container has to be broadened to a `Dict`.) - cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); - -julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) -true +julia> # Condition on some values. + cm = m | (; x = 100.0, m = 1.0); -julia> # Since we conditioned on `a.m`, it is not treated as a random variable. - # However, `a.x` will still be a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # Returns all the variables we have conditioned on, and their values. + conditioned(cm) +(x = 100.0, m = 1.0) -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # If we prefix the model, the conditioned variables will also be prefixed. + pm = prefix(cm, @varname(f)); conditioned(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 -julia> conditioned(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 +julia> # If we condition _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); cm2 = pm2 | (; x = 100.0, m = 1.0); -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # When running this model, the variables inside are not treated as conditioned! + conditioned(cm2) +(x = 100.0, m = 1.0) ``` """ conditioned(model::Model) = conditioned(model.context) @@ -760,7 +778,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed julia> @model function demo() m ~ Normal() @@ -770,31 +788,25 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have fixed on + their values. - fixed(fix(m, x=100.0, m=1.0)) -(x = 100.0, m = 1.0) - -julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); - -julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) -true +julia> # Fix some values. + fm = fix(m, (; x = 100.0, m = 1.0)); -julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # Returns all the variables we have fixed on, and their values. + fixed(fm) +(x = 100.0, m = 1.0) -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # If we prefix the model, the fixed variables will also be prefixed. 
+ pm = prefix(fm, @varname(f)); fixed(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 -julia> fixed(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 +julia> # If we fix _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); fm2 = fix(pm2, (; x = 100.0, m = 1.0)); -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # When running this model, the variables inside are not treated as fixed! + fixed(fm2) +(x = 100.0, m = 1.0) ``` """ fixed(model::Model) = fixed(model.context) @@ -815,7 +827,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -829,29 +841,36 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=PriorInit()] + ) -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added, again +using the specified initialisation strategy. -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). +If `init_strategy` is not provided, it defaults to `PriorInit()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function evaluate_and_sample!!( +function init!!( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), + init_strategy::AbstractInitStrategy=PriorInit(), ) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() +function init!!( + model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() ) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) + return init!!(Random.default_rng(), model, varinfo, init_strategy) end """ @@ -981,11 +1000,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`.
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end @@ -1157,25 +1172,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) - -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches -the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. -""" -function predict( - rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} -) - varinfo = DynamicPPL.VarInfo(model) - return map(chain) do params_varinfo - vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) - return vi - end -end +# Implemented & documented in DynamicPPLMCMCChainsExt +function predict end """ returned(model::Model, parameters::NamedTuple) diff --git a/src/prefix.jl b/src/prefix.jl new file mode 100644 index 000000000..c8f258cac --- /dev/null +++ b/src/prefix.jl @@ -0,0 +1,108 @@ +""" + maybe_prefix(inner::Union{Nothing,<:VarName}, outer::Union{Nothing,<:VarName}) + +Prefix `inner` with the prefix `outer`. Both `inner` and `outer` can be either +`VarName`s or `Nothing`. + +Note that this differs from `AbstractPPL.prefix` in that it handles `nothing` values. +This can happen e.g. when prefixing a model that is not already prefixed; or when +executing submodels without automatic prefixing. +""" +maybe_prefix(inner::VarName, outer::VarName) = AbstractPPL.prefix(inner, outer) +maybe_prefix(vn::VarName, ::Nothing) = vn +maybe_prefix(::Nothing, vn::VarName) = vn +maybe_prefix(::Nothing, ::Nothing) = nothing + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. 
+ +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end + +""" + DynamicPPL.prefix(model::Model, x::VarName) + DynamicPPL.prefix(model::Model, x::Val{sym}) + DynamicPPL.prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +If `x` is `nothing`, then the model is returned unchanged. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, ::Nothing) = model +function prefix(model::Model, vn::VarName) + # Add it to the model prefix field + new_prefix = maybe_prefix(model.prefix, vn) + # And also make sure to prefix any conditioned and fixed variables stored in the model + new_context = prefix_cond_and_fixed_variables(model.context, vn) + return Model(model.f, model.args, model.defaults, new_context, new_prefix) +end +prefix(model::Model, ::Val{sym}) where {sym} = prefix(model, VarName{sym}()) +prefix(model::Model, x) = return prefix(model, VarName{Symbol(x)}()) diff --git a/src/sampler.jl b/src/sampler.jl index 673b5128f..711865008 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. 
- -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -41,7 +10,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL. provided that supports resuming sampling from a previous state and setting initial parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) +respectively. Additionally, sometimes one might want to implement [`init_strategy`](@ref) that specifies how the initial parameter values are sampled if they are not provided. By default, values are sampled from the prior. """ @@ -49,24 +18,13 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) - return vi, nothing -end - """ default_varinfo(rng, model, sampler) Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. @@ -75,9 +33,10 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler) + # Note that variable values are unconditionally initialized later, so no + # point putting them in now. + return typed_varinfo(VarInfo()) end function AbstractMCMC.sample( @@ -95,24 +54,32 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden. +""" +init_strategy(::Sampler) = PriorInit() + function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. 
+ # Generate the default varinfo (usually this just makes an empty VarInfo + # with NamedTuple of Metadata). vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. Note that, if `ParamsInit` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -130,110 +97,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. - new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. 
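To connect the `init_strategy` hook with the new `initial_params::AbstractInitStrategy` keyword of `AbstractMCMC.step` above: a sampler can declare its preferred initialisation strategy, and callers can still override it per run. A rough sketch with a hypothetical algorithm type `MyAlg` (only the initialisation part is shown; `initialstep` and the rest of the sampler interface are omitted, and it is assumed that `sample` forwards `initial_params` down to `step` as usual):

```julia
using DynamicPPL

struct MyAlg end  # hypothetical inference algorithm

# Make samplers wrapping MyAlg initialise from a uniform distribution in
# unconstrained space instead of the default PriorInit().
DynamicPPL.init_strategy(::Sampler{MyAlg}) = UniformInit()

# A caller can still override this for a single run, e.g. to start from
# known parameter values (assuming the keyword is forwarded to `step`):
# sample(model, Sampler(MyAlg()), 100; initial_params=ParamsInit((; x=0.0)))
```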
" * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad22bf52d..9bb56830d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! 
- _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -232,24 +232,23 @@ end # Constructor from `Model`. function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. @@ -265,12 +264,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -463,26 +462,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) end @@ -492,6 +471,16 @@ end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end +function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end +end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) diff --git a/src/submodel.jl b/src/submodel.jl index dcb107bb4..4ed22db1d 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -158,28 +158,32 @@ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}( # passed into this function. # # `parent_context` here refers to the context of the model that contains the -# submodel. +# submodel. `parent_prefix` is the prefix that is applied to the parent model. function _evaluate!!( submodel::Submodel{M,AutoPrefix}, vi::AbstractVarInfo, parent_context::AbstractContext, - left_vn::VarName, + parent_prefix::Union{Nothing,<:VarName}, + vn::VarName, ) where {M<:Model,AutoPrefix} # First, we construct the context to be used when evaluating the submodel. There # are several considerations here: - # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but - # _only_ if automatic prefixing is supposed to be applied. - submodel_context_prefixed = if AutoPrefix - PrefixContext(left_vn, submodel.model.context) + + # (1) Before even touching the contexts, we need to make sure that we apply + # automatic prefixing if it was requested. (If the prefix was manually applied, then + # `prefix()` will have been called by the user, and we don't need to do it again.) + submodel_prefix = if AutoPrefix + # Note that by the time we see it here (in `tilde_assume!!`), `vn` + # has already prefixed with `parent_prefix`, so no need to re-prefix it + vn else - submodel.model.context + parent_prefix end + submodel_model = DynamicPPL.prefix(submodel.model, submodel_prefix) # (2) We need to respect the leaf-context of the parent model. This, unfortunately, # means disregarding the leaf-context of the submodel. - submodel_context = setleafcontext( - submodel_context_prefixed, leafcontext(parent_context) - ) + submodel_context = setleafcontext(submodel_model.context, leafcontext(parent_context)) # (3) We need to use the parent model's context to wrap the whole thing, so that # e.g. if the user conditions the parent model, the conditioned variables will be @@ -187,7 +191,7 @@ function _evaluate!!( eval_context = setleafcontext(parent_context, submodel_context) # (4) Finally, we need to store that context inside the submodel. - model = contextualize(submodel.model, eval_context) + model = contextualize(submodel_model, eval_context) # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This # returns a tuple of submodel.model's return value and the new varinfo. 
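The submodel changes above replace `PrefixContext` with `DynamicPPL.prefix(model, vn)`, so with automatic prefixing the left-hand side of the `~` statement becomes the submodel's prefix. A small sketch of the user-visible effect (variable names as I read them from the diff, not verified output):

```julia
using DynamicPPL, Distributions

@model inner() = x ~ Normal()

@model function outer()
    # Automatic prefixing (the default for `to_submodel`): the submodel's
    # variable is stored under the left-hand side name, i.e. as `a.x`.
    a ~ to_submodel(inner())
    y ~ Normal(a, 1)
end

keys(VarInfo(outer()))
# expected to contain @varname(a.x) and @varname(y)
```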
diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..4a019441b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod node_trait = DynamicPPL.NodeTrait(context) # Throw error immediately if it it's missing a `NodeTrait` implementation. node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) + error("Invalid NodeTrait: $node_trait") - # To see change, let's make sure we're using a different leaf context than the current. - leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) else - DefaultContext() + test_parent_context(context, model) end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf + + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. Thus we only test evaluation (i.e., assuming that the + # varinfo already contains all necessary variables). + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + new_model = contextualize(model, context) + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + @testset "{set,}{leaf,child}context" begin + # Ensure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. 
- varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/transforming.jl b/src/transforming.jl index 56f861cff..22493b49b 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -12,8 +12,8 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( - ::DynamicTransformationContext{isinverse}, right, vn, vi +function tilde_assume!!( + ::DynamicTransformationContext{isinverse}, prefix, right::Distribution, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +31,7 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) +function tilde_observe!!(::DynamicTransformationContext, right::Distribution, left, vn, vi) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/utils.jl b/src/utils.jl index d3371271f..c70576a5d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) 
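For reference, the `DynamicTransformationContext` methods above show the shape of the new tilde-pipeline overloads for a leaf context: `tilde_assume!!(ctx, prefix, right, vn, vi)` and `tilde_observe!!(ctx, right, left, vn, vi)`. A very rough sketch of a custom leaf context that simply delegates to `DefaultContext` (the `EchoContext` name is made up, and it is an assumption that `DefaultContext` provides tilde methods with these exact signatures, mirroring the `DynamicTransformationContext` methods in the diff):

```julia
using DynamicPPL

struct EchoContext <: DynamicPPL.AbstractContext end
DynamicPPL.NodeTrait(::EchoContext) = DynamicPPL.IsLeaf()

# Log each variable, then delegate the actual work to `DefaultContext`.
function DynamicPPL.tilde_assume!!(::EchoContext, prefix, right, vn, vi)
    @info "assume" vn
    return DynamicPPL.tilde_assume!!(DynamicPPL.DefaultContext(), prefix, right, vn, vi)
end
function DynamicPPL.tilde_observe!!(::EchoContext, right, left, vn, vi)
    @info "observe" vn
    return DynamicPPL.tilde_observe!!(DynamicPPL.DefaultContext(), right, left, vn, vi)
end
```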
.- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/src/varinfo.jl b/src/varinfo.jl index e115a6799..26a5c34ac 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -113,10 +113,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -129,12 +133,12 @@ the given `rng`, `sampler`. instead. """ function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -195,7 +199,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +207,15 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -270,7 +274,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +282,19 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +302,25 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=PriorInit() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +328,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +340,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -1508,42 +1514,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. 
- optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. - val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -1966,113 +1936,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke return indices end -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! 
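Stepping back from the removals in this hunk: the `VarInfo`, `untyped_varinfo` and `typed_varinfo` constructors updated above now take an `AbstractInitStrategy` where they previously took a sampler, and the "set these values, draw the rest" workflow that `setval_and_resample!` used to serve appears to be covered by `ParamsInit`. A brief sketch of how the constructor variants line up (a stand-in model; the commented results are expectations, not verified output):

```julia
using DynamicPPL, Distributions, Random

@model demo_beta() = x ~ Beta(2, 2)
m = demo_beta()

vi1 = VarInfo(m)                              # typed varinfo, PriorInit() by default
vi2 = VarInfo(Xoshiro(1), m, UniformInit())   # explicit RNG and strategy
vi3 = DynamicPPL.untyped_varinfo(m, ParamsInit((; x=0.5)))
vi3[@varname(x)]                              # expected: 0.5, the supplied value
```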
- unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new diff --git a/test/Project.toml b/test/Project.toml index 6da3786f5..3bd424237 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -39,7 +38,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9, 0.10" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 371e79b06..23e676ee7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,48 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). 
We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..874b71204 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,8 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +595,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +617,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..058b98d14 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -18,12 +18,9 @@ using DynamicPPL: conditioned, fixed, hasconditioned_nested, - getconditioned_nested, - collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue - -using EnzymeCore + getconditioned_nested +using LinearAlgebra: I +using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? 
function Base.iterate(context::AbstractContext) @@ -49,16 +46,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), - :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - :condition3 => ConditionContext( - (x=1.0,), - PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), - ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -103,7 +94,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -120,104 +111,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext" begin - @testset "prefixing" begin - ctx = @inferred PrefixContext( - @varname(a), - PrefixContext( - @varname(b), - PrefixContext( - @varname(c), - PrefixContext( - @varname(d), - PrefixContext( - @varname(e), PrefixContext(@varname(f), DefaultContext()) - ), - ), - ), - ), - ) - vn = @varname(x) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x) - - vn = @varname(x[1]) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) - end - - @testset "nested within arbitrary context stacks" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) - @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext(@varname(b), ctx2) - @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) - @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) - end - - @testset "prefix_and_strip_contexts" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == DefaultContext() - - ctx2 = SamplingContext(PrefixContext(@varname(a))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() - - ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == ConditionContext((a=1,)) - - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) - end - - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. 
- varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) - # Extract the resulting varnames - vns_actual = Set(keys(varinfo)) - - # Extract the ground truth varnames - vns_expected = Set([ - AbstractPPL.prefix(vn, prefix_vn) for - vn in DynamicPPL.TestUtils.varnames(model) - ]) - - # Check that all variables are prefixed correctly. - @test vns_actual == vns_expected - end - end - - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin @@ -333,102 +226,207 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext + Condition/FixedContext interactions" begin - @testset "prefix_cond_and_fixed_variables" begin - c1 = ConditionContext((c=1, d=2)) - c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) - @test c1_prefixed isa ConditionContext - @test childcontext(c1_prefixed) isa DefaultContext - @test c1_prefixed.values[@varname(a.c)] == 1 - @test c1_prefixed.values[@varname(a.d)] == 2 - - c2 = FixedContext((f=1, g=2)) - c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) - @test c2_prefixed isa FixedContext - @test childcontext(c2_prefixed) isa DefaultContext - @test c2_prefixed.values[@varname(a.f)] == 1 - @test c2_prefixed.values[@varname(a.g)] == 2 - - c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) - c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) - c3_prefixed_child = childcontext(c3_prefixed) - @test c3_prefixed isa ConditionContext - @test c3_prefixed.values[@varname(a.c)] == 1 - @test c3_prefixed.values[@varname(a.d)] == 2 - @test c3_prefixed_child isa FixedContext - @test c3_prefixed_child.values[@varname(a.f)] == 1 - @test c3_prefixed_child.values[@varname(a.g)] == 2 - @test childcontext(c3_prefixed_child) isa DefaultContext + @testset "InitContext" begin + empty_varinfos = [ + VarInfo(), + DynamicPPL.typed_varinfo(VarInfo()), + VarInfo(DynamicPPL.VarNamedVector()), + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + SimpleVarInfo(), + SimpleVarInfo(Dict{VarName,Any}()), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing end - - @testset "collapse_prefix_stack" begin - # Utility function to make sure that there are no PrefixContexts in - # the context stack. - function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. 
+ model = test_init_model() + for empty_vi in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + for empty_vi in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + for empty_vi in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) ) end + end - # Prefix -> Condition - c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) - c1 = collapse_prefix_stack(c1) - @test has_no_prefixcontexts(c1) - c1_vals = conditioned(c1) - @test length(c1_vals) == 2 - @test getvalue(c1_vals, @varname(a.c)) == 1 - @test getvalue(c1_vals, @varname(a.d)) == 2 - - # Condition -> Prefix - c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) - c2 = collapse_prefix_stack(c2) - @test has_no_prefixcontexts(c2) - c2_vals = conditioned(c2) - @test length(c2_vals) == 2 - @test getvalue(c2_vals, @varname(c)) == 1 - @test getvalue(c2_vals, @varname(d)) == 2 - - # Prefix -> Fixed - c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) - c3 = collapse_prefix_stack(c3) - c3_vals = fixed(c3) - @test length(c3_vals) == 2 - @test length(c3_vals) == 2 - @test getvalue(c3_vals, 
@varname(a.f)) == 1 - @test getvalue(c3_vals, @varname(a.g)) == 2 + @testset "PriorInit" begin + test_generating_new_values(PriorInit()) + test_replacing_values(PriorInit()) + test_rng_respected(PriorInit()) + test_link_status_respected(PriorInit()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end - # Fixed -> Prefix - c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) - c4 = collapse_prefix_stack(c4) - @test has_no_prefixcontexts(c4) - c4_vals = fixed(c4) - @test length(c4_vals) == 2 - @test getvalue(c4_vals, @varname(f)) == 1 - @test getvalue(c4_vals, @varname(g)) == 2 + @testset "UniformInit" begin + test_generating_new_values(UniformInit()) + test_replacing_values(UniformInit()) + test_rng_respected(UniformInit()) + test_link_status_respected(UniformInit()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end - # Prefix -> Condition -> Prefix -> Condition - c5 = PrefixContext( - @varname(a), - ConditionContext( - (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) - ), - ) - c5 = collapse_prefix_stack(c5) - @test has_no_prefixcontexts(c5) - c5_vals = conditioned(c5) - @test length(c5_vals) == 2 - @test getvalue(c5_vals, @varname(a.c)) == 1 - @test getvalue(c5_vals, @varname(a.b.d)) == 2 + @testset "ParamsInit" begin + test_link_status_respected(ParamsInit((; a=1.0))) + test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0))) + + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end - # Prefix -> Condition -> Prefix -> Fixed - c6 = PrefixContext( - @varname(a), - ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), - ) - c6 = collapse_prefix_stack(c6) - @test has_no_prefixcontexts(c6) - @test conditioned(c6) == Dict(@varname(a.c) => 1) - @test fixed(c6) == Dict(@varname(a.b.d) => 2) + @testset "given only partial parameters" begin + # In this case, we expect `ParamsInit` to use the value of x, and + # generate a new value for y. 
+ my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + end end end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..a820d885e 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -40,6 +40,11 @@ end end @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa + DynamicPPL.NTVarInfo + init_model = DynamicPPL.contextualize( + demo4(), DynamicPPL.InitContext(DynamicPPL.PriorInit()) + ) + @test DynamicPPL.Experimental.determine_suitable_varinfo(init_model) isa DynamicPPL.UntypedVarInfo # In this model, the type error occurs in the user code rather than in DynamicPPL. @@ -62,19 +67,14 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling + # Check that the inferred varinfo is indeed suitable for evaluation f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. 
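Summarising what the `ParamsInit` tests above exercise: supplied values are inserted as-is, and any variable not supplied is freshly initialised. A condensed sketch mirroring those tests:

```julia
using DynamicPPL, Distributions, Random
using LinearAlgebra: I

@model function demo_init()
    x ~ Normal()
    y ~ MvNormal(fill(x, 2), I)
end

m = demo_init()
_, vi = DynamicPPL.init!!(Xoshiro(468), m, VarInfo(), ParamsInit((; x=1.0)))
vi[@varname(x)]   # == 1.0, taken from the supplied parameters
vi[@varname(y)]   # a fresh 2-element vector, since `y` was not supplied
```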
is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed @@ -85,10 +85,6 @@ model, typed_vi ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi - ) - JET.test_call(f_sample, argtypes_sample) end end end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..03e744b84 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -38,20 +33,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. 
- @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/model.jl b/test/model.jl index 81f84e548..964383c56 100644 --- a/test/model.jl +++ b/test/model.jl @@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() logjoint(model, chain) end - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +314,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -513,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -559,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), ) predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) @test size(multiple_β_chain, 3) == size(predictions, 3) @@ -584,43 +572,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, 
length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end end @testset "ProductNamedTupleDistribution sampling" begin diff --git a/test/prefix.jl b/test/prefix.jl new file mode 100644 index 000000000..57065689b --- /dev/null +++ b/test/prefix.jl @@ -0,0 +1,121 @@ +""" +Note that `test/submodels.jl` also contains a number of tests which make use of +prefixing functionality (more like end-to-end tests). This file contains what +are essentially unit tests for prefixing functions. +""" +module DPPLPrefixTests + +using DynamicPPL +# not exported +using DynamicPPL: FixedContext, prefix_cond_and_fixed_variables, childcontext +using Distributions +using Test + +@testset "prefix.jl" begin + @testset "prefix_cond_and_fixed_variables" begin + @testset "ConditionContext" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test length(c1_prefixed.values) == 2 + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + end + + @testset "FixedContext" begin + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test length(c2_prefixed.values) == 2 + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + end + + @testset "Nested ConditionContext and FixedContext" begin + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test length(c3_prefixed.values) == 2 + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test length(c3_prefixed_child.values) == 2 + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + end + + @testset "DynamicPPL.prefix(::Model, x)" begin + @model function demo() + x ~ Normal() + return y ~ Normal() + end + model = demo() + + @testset "No conditioning / fixing" begin + pmodel = DynamicPPL.prefix(model, @varname(a)) + @test pmodel.prefix == @varname(a) + vi = VarInfo(pmodel) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end + + @testset "Prefixing a conditioned model" begin + cmodel = model | (; x=1.0) + # Sanity check. + vi = VarInfo(cmodel) + @test Set(keys(vi)) == Set([@varname(y)]) + # Now prefix. + pcmodel = DynamicPPL.prefix(cmodel, @varname(a)) + @test pcmodel.prefix == @varname(a) + # Because the model was conditioned on `x` _prior_ to prefixing, + # the resulting `a.x` variable should also be conditioned. In + # other words, which variables are treated as conditioned should be + # invariant to prefixing.
+ vi = VarInfo(pcmodel) + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + + @testset "Prefixing a fixed model" begin + # Same as above but for FixedContext rather than Condition. + fmodel = fix(model, (; y=1.0)) + # Sanity check. + vi = VarInfo(fmodel) + @test Set(keys(vi)) == Set([@varname(x)]) + # Now prefix. + pfmodel = DynamicPPL.prefix(fmodel, @varname(a)) + @test pfmodel.prefix == @varname(a) + # Because the model was fixed on `y` _prior_ to prefixing, the + # resulting `a.y` variable should also be fixed. In other words, + # which variables are treated as fixed should be invariant to + # prefixing. + vi = VarInfo(pfmodel) + @test Set(keys(vi)) == Set([@varname(a.x)]) + end + + @testset "Conditioning a prefixed model" begin + # If the prefixing happens first, then we want to make sure that the + # user is forced to apply conditioning WITH the prefix. + pmodel = DynamicPPL.prefix(model, @varname(a)) + + # If this doesn't happen... + cpmodel_wrong = pmodel | (; x=1.0) + @test cpmodel_wrong.prefix == @varname(a) + vi = VarInfo(cpmodel_wrong) + # Then `a.x` will be `assume`d + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + + # If it does... + cpmodel_right = pmodel | (@varname(a.x) => 1.0) + @test cpmodel_right.prefix == @varname(a) + vi = VarInfo(cpmodel_right) + # Then `a.x` will be `observe`d + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + end +end + +end diff --git a/test/sampler.jl b/test/sampler.jl index fe9fd331a..5b6a623e8 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,58 +1,4 @@ @testset "sampler.jl" begin - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
- @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling abstract type OnlyInitAlg end @@ -69,8 +15,8 @@ end # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() + DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = UniformInit() + @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == PriorInit() for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) # model with one variable: initialization p = 0.2 @@ -81,7 +27,7 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) + let inits = ParamsInit((; p=0.2)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -109,7 +55,7 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) + let inits = ParamsInit((; s=4, m=-1)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @@ -133,7 +79,7 @@ end # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) + for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1))) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -153,54 +99,6 @@ @test c[1].metadata.m.vals == [-1] end end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) end end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index be6deb96e..93c7b069e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -160,7 +160,7 @@ ### Sampling ### # Sample a new varinfo! 
- _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -228,9 +228,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -275,7 +275,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/submodels.jl b/test/submodels.jl index 986aea1d0..7463ed0e2 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -135,7 +135,54 @@ end end end - @testset "Nested submodels" begin + @testset "Nested submodels with auto prefix" begin + @model function f() + x ~ Normal() + return y ~ Normal() + end + @model function g() + return b ~ to_submodel(f()) + end + @model function h() + return a ~ to_submodel(g()) + end + + # No conditioning + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogjoint(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) + + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + + # Conditioning/fixing at the second level + op_g = op(g(), (@varname(b.x) => x_val)) + @model function h2() + return a ~ to_submodel(op_g) + end + + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) + @model function g2() + return _unused ~ to_submodel(prefix(op_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end + end + + @testset "Nested submodels with manual prefix" begin + # Same tests as above, just that the middle layer has manual prefixing + # rather than automatic. 
@model function f() x ~ Normal() return y ~ Normal() diff --git a/test/test_util.jl b/test/test_util.jl index d5335249d..b7c46ff34 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 24a738a78..c86f8da69 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 202ddc1b2..3cc547449 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -42,7 +42,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -262,7 +262,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -313,8 +313,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -339,57 +339,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. 
- continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -403,9 +352,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! on chain" begin @@ -470,17 +416,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using UniformInit does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using UniformInit specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to UniformInit tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, UniformInit()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -544,7 +491,7 @@ end function test_linked_varinfo(model, vi) # vn and dist are taken from the containing scope - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test istrans(vi, vn) @@ -555,6 +502,11 @@ end @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. 
+ ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) @@ -565,11 +517,6 @@ end vi = DynamicPPL.settrans!!(vi, true, vn) test_linked_varinfo(model, vi) - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi) - ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) @@ -960,10 +907,9 @@ end end model1 = demo(1) varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -979,10 +925,9 @@ end end model1 = demo_dot(1) varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..af24be86f 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different.
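
Reviewer note (not part of the diff): the recurring pattern in these test changes is the replacement of `evaluate_and_sample!!` / `SampleFromPrior` / `SampleFromUniform` with `DynamicPPL.init!!` plus an initialisation strategy. The sketch below is pieced together from the calls that appear in the changed test files above; the `demo` model and the numeric values are illustrative assumptions, not code from this PR.

```julia
# Minimal sketch of the initialisation API exercised by the updated tests.
using DynamicPPL, Distributions, Random

@model function demo()
    s ~ InverseGamma(2, 3)
    return m ~ Normal(0, sqrt(s))
end

model = demo()

# Draw fresh values from the prior into an empty VarInfo; an RNG may be passed
# as the first argument (as in the ParamsInit tests) or omitted.
_, vi_prior = DynamicPPL.init!!(Random.default_rng(), model, VarInfo(), PriorInit())

# Sample uniformly in linked (unconstrained) space within the given bounds.
_, vi_unif = DynamicPPL.init!!(model, VarInfo(), UniformInit(-2.0, 2.0))

# Use supplied values where available; variables not listed (here `m`) get new values.
_, vi_params = DynamicPPL.init!!(model, VarInfo(), ParamsInit((; s=3.0)))
vi_params[@varname(s)]  # == 3.0
```

As in the tests, `init!!` returns the model's return value together with the updated varinfo, so either destructuring or `last(DynamicPPL.init!!(...))` yields the varinfo.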