From cab48c6970d59735ddd290285740021460cbd630 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Mar 2025 10:33:25 +0000 Subject: [PATCH 1/2] AbstractPPL 0.11; change prefixing behaviour --- HISTORY.md | 49 +++++++++++++ Project.toml | 2 +- src/DynamicPPL.jl | 5 +- src/compiler.jl | 4 +- src/context_implementations.jl | 6 +- src/contexts.jl | 40 +++++----- src/debug_utils.jl | 2 +- src/model.jl | 58 +++++---------- src/submodel_macro.jl | 16 ++-- src/utils.jl | 8 +- src/values_as_in_model.jl | 2 +- test/Project.toml | 2 +- test/compiler.jl | 9 ++- test/contexts.jl | 130 ++++++++++++++++++--------------- test/deprecated.jl | 2 +- test/model.jl | 2 +- 16 files changed, 191 insertions(+), 146 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 3ea8071f3..cd2757edc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,54 @@ # DynamicPPL Changelog +## 0.36.0 + +**Breaking changes** + +### VarName prefixing behaviour + +The way in which VarNames in submodels are prefixed has been changed. +This is best explained through an example. +Consider this model and submodel: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model outer() = a ~ to_submodel(inner()) +``` + +In previous versions, the inner variable `x` would be saved as `a.x`. +However, this was represented as a single symbol `Symbol("a.x")`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{Symbol("a.x"), typeof(identity)} + optic: identity (function of type typeof(identity)) +``` + +Now, the inner variable is stored as a field `x` on the VarName `a`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{:a, Accessors.PropertyLens{:x}} + optic: Accessors.PropertyLens{:x} (@o _.x) +``` + +In practice, this means that if you are trying to condition a variable in the submodel, you now need to use + +```julia +outer() | (@varname(a.x) => 1.0,) +``` + +instead of either of these (which would have worked previously) + +```julia +outer() | (@varname(var"a.x") => 1.0,) +outer() | (a.x=1.0,) +``` + +If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) + ## 0.35.5 Several internal methods have been removed: diff --git a/Project.toml b/Project.toml index d5185d727..516dee26e 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 50fe0edc7..519a34d58 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -21,6 +21,9 @@ using DocStringExtensions using Random: Random +# For extending +import AbstractPPL: predict, prefix + # TODO: Remove these when it's possible. import Bijectors: link, invlink @@ -39,8 +42,6 @@ import Base: keys, haskey -import AbstractPPL: predict - # VarInfo export AbstractVarInfo, VarInfo, diff --git a/src/compiler.jl b/src/compiler.jl index 95e76778b..e16edc11b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -113,7 +113,7 @@ function contextual_isassumption(context::ConditionContext, vn) return contextual_isassumption(childcontext(context), vn) end function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix(context, vn)) + return contextual_isassumption(childcontext(context), prefix_with_context(context, vn)) end isfixed(expr, vn) = false @@ -132,7 +132,7 @@ function contextual_isfixed(context::AbstractContext, vn) return contextual_isfixed(NodeTrait(context), context, vn) end function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix(context, vn)) + return contextual_isfixed(childcontext(context), prefix_with_context(context, vn)) end function contextual_isfixed(context::FixedContext, vn) if hasfixed(context, vn) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..990fc70c1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,12 +85,14 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) + return tilde_assume(context.context, right, prefix_with_context(context, vn), vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) + return tilde_assume( + rng, context.context, sampler, right, prefix_with_context(context, vn), vi + ) end """ diff --git a/src/contexts.jl b/src/contexts.jl index a54c60374..d63f4f1b6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,25 +260,25 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} return PrefixContext{Prefix}(child) end -const PREFIX_SEPARATOR = Symbol(".") - -@generated function PrefixContext{PrefixOuter}( - context::PrefixContext{PrefixInner} -) where {PrefixOuter,PrefixInner} - return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - context.context - )) -end +""" + prefix_with_context(ctx::AbstractContext, vn::VarName) -function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - vn_prefixed_inner = prefix(childcontext(ctx), vn) - return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}( - getoptic(vn_prefixed_inner) +Apply the prefixes in the context `ctx` to the variable name `vn`. +""" +function prefix_with_context( + ctx::PrefixContext{Prefix}, vn::VarName{Sym} +) where {Prefix,Sym} + return AbstractPPL.prefix( + prefix_with_context(childcontext(ctx), vn), VarName{Symbol(Prefix)}() ) end -prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn) +function prefix_with_context(ctx::AbstractContext, vn::VarName) + return prefix_with_context(NodeTrait(ctx), ctx, vn) +end +prefix_with_context(::IsLeaf, ::AbstractContext, vn::VarName) = vn +function prefix_with_context(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix_with_context(childcontext(ctx), vn) +end """ prefix(model::Model, x) @@ -392,7 +392,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix(context, vn)) + return hasconditioned_nested(childcontext(context), prefix_with_context(context, vn)) end """ @@ -410,7 +410,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix(context, vn)) + return getconditioned_nested(childcontext(context), prefix_with_context(context, vn)) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -543,7 +543,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix(context, vn)) + return hasfixed_nested(childcontext(context), prefix_with_context(context, vn)) end """ @@ -561,7 +561,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix(context, vn)) + return getfixed_nested(childcontext(context), prefix_with_context(context, vn)) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 328fe6983..78024ec47 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = prefix(context, varname) + prefixed_varname = DynamicPPL.prefix_with_context(context, varname) if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure error("varname $prefixed_varname used multiple times in model") diff --git a/src/model.jl b/src/model.jl index a0451b1b6..b4d5f6bb7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -243,7 +243,7 @@ julia> model() ≠ 1.0 true julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`. - conditioned_model = model | (var"inner.m" = 1.0, ); + conditioned_model = model | (@varname(inner.m) => 1.0, ); julia> conditioned_model() 1.0 @@ -255,15 +255,6 @@ julia> conditioned_model_fail() ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported [...] ``` - -And similarly when using `Dict`: - -```jldoctest condition -julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0); - -julia> conditioned_model_dict() -1.0 -``` """ function AbstractPPL.condition(model::Model, values...) # Positional arguments - need to handle cases carefully @@ -443,16 +434,16 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0); + cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); -julia> conditioned(cm).x +julia> conditioned(cm)[@varname(x)] 100.0 -julia> conditioned(cm).var"a.m" +julia> conditioned(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # No variables are sampled @@ -583,7 +574,7 @@ julia> model = demo_outer(); julia> model() ≠ 1.0 true -julia> fixed_model = fix(model, var"inner.m" = 1.0, ); +julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, )); julia> fixed_model() 1.0 @@ -599,24 +590,9 @@ julia> fixed_model() 2.0 ``` -And similarly when using `Dict`: - -```jldoctest fix -julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0); - -julia> fixed_model_dict() -1.0 - -julia> fixed_model_dict = fix(model, @varname(inner) => 2.0); - -julia> fixed_model_dict() -2.0 -``` - ## Difference from `condition` -A very similar functionality is also provided by [`condition`](@ref) which, -not surprisingly, _conditions_ variables instead of fixing them. The only +A very similar functionality is also provided by [`condition`](@ref). The only difference between fixing and conditioning is as follows: - `condition`ed variables are considered to be observations, and are thus included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref), @@ -798,16 +774,16 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0); + cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); -julia> fixed(cm).x +julia> fixed(cm)[@varname(x)] 100.0 -julia> fixed(cm).var"a.m" +julia> fixed(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # <= no variables are sampled @@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be ```jldoctest submodel-to_submodel julia> vi = VarInfo(demo2(missing, 0.4)); -julia> @varname(var\"a.x\") in keys(vi) +julia> @varname(a.x) in keys(vi) true ``` @@ -1379,7 +1355,7 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel -julia> x = vi[@varname(var\"a.x\")]; +julia> x = vi[@varname(a.x)]; julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true @@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z) julia> vi = VarInfo(demo2(missing, missing, 0.4)); -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -1437,9 +1413,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index e5a8e0617..f6b9c4479 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -96,10 +96,10 @@ julia> vi = VarInfo(demo2(missing, missing, 0.4)); │ caller = ip:0x0 └ @ Core :-1 -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -116,9 +116,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodelprefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); @@ -157,7 +157,7 @@ julia> # Automatically determined from `a`. @model submodel_prefix_true() = @submodel prefix=true a = inner() submodel_prefix_true (generic function with 2 methods) -julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true())) +julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -167,7 +167,7 @@ julia> # Using a static string. @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() submodel_prefix_string (generic function with 2 methods) -julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -177,7 +177,7 @@ julia> # Using string interpolation. @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() submodel_prefix_interpolation (generic function with 2 methods) -julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression. @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() submodel_prefix_expr (generic function with 2 methods) -julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/utils.jl b/src/utils.jl index 50f9baf61..56c3d70af 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1285,14 +1285,18 @@ broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) +# Convert (x=1,) to Dict(@varname(x) => 1) +_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. # TODO: Maybe replace the default of returning `NamedTuple` with `nothing`? _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) -_merge(left::AbstractDict, right::NamedTuple{()}) = left -_merge(left::NamedTuple{()}, right::AbstractDict) = right +_merge(left::AbstractDict, ::NamedTuple{()}) = left +_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(::NamedTuple{()}, right::AbstractDict) = right +_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..faf4331fb 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false is_extracting_values(::IsLeaf, ::AbstractContext) = false function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), prefix(context, vn)) + return setindex!(context.values, copy(value), prefix_with_context(context, vn)) end function broadcast_push!(context::ValuesAsInModelContext, vns, values) diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..79e6d129b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1" diff --git a/test/compiler.jl b/test/compiler.jl index 3d3c6d9e3..3b7ebc617 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -481,8 +481,8 @@ module Issue537 end m = demo_useval(missing, missing) vi = VarInfo(m) ks = keys(vi) - @test VarName{Symbol("sub1.x")}() ∈ ks - @test VarName{Symbol("sub2.x")}() ∈ ks + @test @varname(sub1.x) ∈ ks + @test @varname(sub2.x) ∈ ks @test @varname(z) ∈ ks @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 @@ -514,8 +514,9 @@ module Issue537 end m = demo(ys) vi = VarInfo(m) - for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] - @test VarName{k}() ∈ keys(vi) + for vn in + [@varname(α), @varname(μ), @varname(σ), @varname(ar1_1.η), @varname(ar1_2.η)] + @test vn ∈ keys(vi) end end diff --git a/test/contexts.jl b/test/contexts.jl index faa831cc1..5d6e2e49f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -39,44 +39,39 @@ end Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() -""" - remove_prefix(vn::VarName) - -Return `vn` but now with the prefix removed. -""" -function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getoptic(vn) +@testset "contexts.jl" begin + child_contexts = Dict( + :default => DefaultContext(), + :prior => PriorContext(), + :likelihood => LikelihoodContext(), ) -end -@testset "contexts.jl" begin - child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] - - parent_contexts = [ - DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - SamplingContext(), - MiniBatchContext(DefaultContext(), 0.0), - PrefixContext{:x}(DefaultContext()), - PointwiseLogdensityContext(), - ConditionContext((x=1.0,)), - ConditionContext( + parent_contexts = Dict( + :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), + :sampling => SamplingContext(), + :minibatch => MiniBatchContext(DefaultContext(), 0.0), + :prefix => PrefixContext{:x}(DefaultContext()), + :pointwiselogdensity => PointwiseLogdensityContext(), + :condition1 => ConditionContext((x=1.0,)), + :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), - ConditionContext((x=[1.0, missing],)), - ] + :condition3 => ConditionContext( + (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + ), + :condition4 => ConditionContext((x=[1.0, missing],)), + ) - contexts = vcat(child_contexts, parent_contexts) + contexts = merge(child_contexts, parent_contexts) - @testset "$(context)" for context in contexts + @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) end end @testset "contextual_isassumption" begin - @testset "$context" for context in contexts + @testset "$(name)" for (name, context) in contexts # Any `context` should return `true` by default. @test contextual_isassumption(context, VarName{gensym(:x)}()) @@ -85,14 +80,28 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end + @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # Let's check elementwise. for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -108,7 +117,7 @@ end end @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$context" for context in contexts + @testset "$name" for (name, context) in contexts fake_vn = VarName{gensym(:x)}() @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) @@ -118,14 +127,26 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() - + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -152,52 +173,43 @@ end ), ) vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + vn_prefixed = @inferred DynamicPPL.prefix_with_context(ctx, vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x) vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + vn_prefixed = @inferred DynamicPPL.prefix_with_context(ctx, vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) ctx1 = PrefixContext{:a}(DefaultContext()) + @test DynamicPPL.prefix_with_context(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) + @test DynamicPPL.prefix_with_context(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext{:b}(ctx2) + @test DynamicPPL.prefix_with_context(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) - vn_prefixed1 = prefix(ctx1, vn) - vn_prefixed2 = prefix(ctx2, vn) - vn_prefixed3 = prefix(ctx3, vn) - vn_prefixed4 = prefix(ctx4, vn) - @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed3) == Symbol("b.a.x") - @test DynamicPPL.getsym(vn_prefixed4) == Symbol("b.a.x") - @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) + @test DynamicPPL.prefix_with_context(ctx4, vn) == @varname(b.a.x[1]) end - context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + prefix = :my_prefix + context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) - # Extract the resulting symbols. - vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + # Extract the resulting varnames + vns_actual = Set(keys(varinfo)) - # Extract the ground truth symbols. - vns_syms = Set([ - Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for + # Extract the ground truth varnames + vns_expected = Set([ + AbstractPPL.prefix(vn, VarName{prefix}()) for vn in DynamicPPL.TestUtils.varnames(model) ]) # Check that all variables are prefixed correctly. - @test vns_syms == vns_varinfo_syms + @test vns_actual == vns_expected end end diff --git a/test/deprecated.jl b/test/deprecated.jl index f12217983..500d3eb7f 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -31,7 +31,7 @@ @test outer()() isa Tuple{Float64,Float64} vi = VarInfo(outer()) @test @varname(x) in keys(vi) - @test @varname(var"sub.x") in keys(vi) + @test @varname(sub.x) in keys(vi) end @testset "logp is still accumulated properly" begin diff --git a/test/model.jl b/test/model.jl index a863b6596..c884e9393 100644 --- a/test/model.jl +++ b/test/model.jl @@ -456,7 +456,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for model in (outer_auto_prefix(), outer_manual_prefix()) vi = VarInfo(model) vns = Set(keys(values_as_in_model(model, false, vi))) - @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + @test vns == Set([@varname(a.x), @varname(b.x)]) end end end From 68e8c8e5755bd0cdc07192c963cd3c9cd1d72336 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 28 Mar 2025 15:26:02 +0000 Subject: [PATCH 2/2] Use DynamicPPL.prefix rather than overloading --- docs/Project.toml | 1 + docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 +- src/compiler.jl | 4 ++-- src/context_implementations.jl | 6 ++---- src/contexts.jl | 28 ++++++++++++---------------- src/debug_utils.jl | 2 +- src/values_as_in_model.jl | 2 +- test/compiler.jl | 2 +- test/contexts.jl | 12 ++++++------ test/debug_utils.jl | 4 ++-- test/model.jl | 4 ++-- 12 files changed, 32 insertions(+), 37 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index fa57f2c1c..40a719e03 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/docs/src/api.md b/docs/src/api.md index 9c8249c97..2f6376f5d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -149,7 +149,7 @@ In the past, one would instead embed sub-models using [`@submodel`](@ref), which In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs -prefix +DynamicPPL.prefix ``` Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 519a34d58..9f45718c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -22,7 +22,7 @@ using DocStringExtensions using Random: Random # For extending -import AbstractPPL: predict, prefix +import AbstractPPL: predict # TODO: Remove these when it's possible. import Bijectors: link, invlink diff --git a/src/compiler.jl b/src/compiler.jl index e16edc11b..95e76778b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -113,7 +113,7 @@ function contextual_isassumption(context::ConditionContext, vn) return contextual_isassumption(childcontext(context), vn) end function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix_with_context(context, vn)) + return contextual_isassumption(childcontext(context), prefix(context, vn)) end isfixed(expr, vn) = false @@ -132,7 +132,7 @@ function contextual_isfixed(context::AbstractContext, vn) return contextual_isfixed(NodeTrait(context), context, vn) end function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix_with_context(context, vn)) + return contextual_isfixed(childcontext(context), prefix(context, vn)) end function contextual_isfixed(context::FixedContext, vn) if hasfixed(context, vn) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 990fc70c1..e4ba5d252 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,14 +85,12 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix_with_context(context, vn), vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume( - rng, context.context, sampler, right, prefix_with_context(context, vn), vi - ) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index d63f4f1b6..58ac612b8 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -261,23 +261,19 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} end """ - prefix_with_context(ctx::AbstractContext, vn::VarName) + prefix(ctx::AbstractContext, vn::VarName) Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix_with_context( - ctx::PrefixContext{Prefix}, vn::VarName{Sym} -) where {Prefix,Sym} - return AbstractPPL.prefix( - prefix_with_context(childcontext(ctx), vn), VarName{Symbol(Prefix)}() - ) +function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) end -function prefix_with_context(ctx::AbstractContext, vn::VarName) - return prefix_with_context(NodeTrait(ctx), ctx, vn) +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) end -prefix_with_context(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix_with_context(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix_with_context(childcontext(ctx), vn) +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) end """ @@ -392,7 +388,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix_with_context(context, vn)) + return hasconditioned_nested(childcontext(context), prefix(context, vn)) end """ @@ -410,7 +406,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix_with_context(context, vn)) + return getconditioned_nested(childcontext(context), prefix(context, vn)) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -543,7 +539,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix_with_context(context, vn)) + return hasfixed_nested(childcontext(context), prefix(context, vn)) end """ @@ -561,7 +557,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix_with_context(context, vn)) + return getfixed_nested(childcontext(context), prefix(context, vn)) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 78024ec47..529092e8e 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = DynamicPPL.prefix_with_context(context, varname) + prefixed_varname = DynamicPPL.prefix(context, varname) if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure error("varname $prefixed_varname used multiple times in model") diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index faf4331fb..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false is_extracting_values(::IsLeaf, ::AbstractContext) = false function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), prefix_with_context(context, vn)) + return setindex!(context.values, copy(value), prefix(context, vn)) end function broadcast_push!(context::ValuesAsInModelContext, vns, values) diff --git a/test/compiler.jl b/test/compiler.jl index 3b7ebc617..a0286d405 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -505,7 +505,7 @@ module Issue537 end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x ~ to_submodel(prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) + x ~ to_submodel(DynamicPPL.prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) y[i] ~ MvNormal(x, 0.01 * I) end end diff --git a/test/contexts.jl b/test/contexts.jl index 5d6e2e49f..11e591f8f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -173,24 +173,24 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() ), ) vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix_with_context(ctx, vn) + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix_with_context(ctx, vn) + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) ctx1 = PrefixContext{:a}(DefaultContext()) - @test DynamicPPL.prefix_with_context(ctx1, vn) == @varname(a.x[1]) + @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) - @test DynamicPPL.prefix_with_context(ctx2, vn) == @varname(a.x[1]) + @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext{:b}(ctx2) - @test DynamicPPL.prefix_with_context(ctx3, vn) == @varname(b.a.x[1]) + @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) - @test DynamicPPL.prefix_with_context(ctx4, vn) == @varname(b.a.x[1]) + @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d4f6601f5..cac52693e 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -63,8 +63,8 @@ # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() - x1 ~ to_submodel(prefix(ModelInner(), :a), false) - x2 ~ to_submodel(prefix(ModelInner(), :b), false) + x1 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :a), false) + x2 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :b), false) return (x1, x2) end model = ModelOuterWorking2() diff --git a/test/model.jl b/test/model.jl index c884e9393..447a9ecaa 100644 --- a/test/model.jl +++ b/test/model.jl @@ -448,8 +448,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() return nothing end @model function outer_manual_prefix() - a ~ to_submodel(prefix(inner(), :a), false) - b ~ to_submodel(prefix(inner(), :b), false) + a ~ to_submodel(DynamicPPL.prefix(inner(), :a), false) + b ~ to_submodel(DynamicPPL.prefix(inner(), :b), false) return nothing end