diff --git a/HISTORY.md b/HISTORY.md index 17b0b2611..dd62ed691 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,7 +4,7 @@ **Breaking changes** -### Submodels +### Submodels: conditioning Variables in a submodel can now be conditioned and fixed in a correct way. See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: @@ -22,38 +22,7 @@ end and the `inner.x` variable will be correctly conditioned. (Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) -### AD testing utilities - -`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. -To disable this, pass the `linked=false` keyword argument. -If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. -This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. -From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. - -### SimpleVarInfo linking / invlinking - -Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. - -### VarInfo constructors - -`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. - -The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. -If you were not using this argument (most likely), then there is no change needed. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - -The `UntypedVarInfo` constructor and type is no longer exported. -If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. - -The `TypedVarInfo` constructor and type is no longer exported. -The _type_ has been replaced with `DynamicPPL.NTVarInfo`. -The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. - -Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. -Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. -Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. - -### VarName prefixing behaviour +### Submodel prefixing The way in which VarNames in submodels are prefixed has been changed. This is best explained through an example. @@ -95,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,) outer() | (a.x=1.0,) ``` -If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected. +Consider the following setup: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model function outer() + a = Vector{Float64}(undef, 1) + a[1] ~ to_submodel(inner()) + return a +end +``` + +In this case, the variable sampled is actually the `x` field of the first element of `a`: + +```julia +julia> only(keys(VarInfo(outer()))) == @varname(a[1].x) +true +``` + +Before this version, it used to be a single variable called `var"a[1].x"`. + +Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + +### VarInfo constructors + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + **Other changes** While these are technically breaking, they are only internal changes and do not affect the public API. diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md index 042a0f77a..f01d24d68 100644 --- a/docs/src/internals/submodel_condition.md +++ b/docs/src/internals/submodel_condition.md @@ -181,10 +181,10 @@ Putting all of the information so far together, what it means is that if we have using DynamicPPL: PrefixContext, ConditionContext, DefaultContext inner_ctx_with_outer_cond = ConditionContext( - Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext()) + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) ) -inner_ctx_with_inner_cond = PrefixContext{:a}( - ConditionContext(Dict(@varname(x) => 1.0), DefaultContext()) +inner_ctx_with_inner_cond = PrefixContext( + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) ) ``` @@ -252,10 +252,11 @@ The general strategy that we adopt is similar to above. Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: ```@example -big_ctx = PrefixContext{:a}( +big_ctx = PrefixContext( + @varname(a), ConditionContext( Dict(@varname(b.y) => 1.0), - PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))), + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), ), ) ``` @@ -280,9 +281,9 @@ end function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) return myprefix(childcontext(ctx), vn) end -function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) # The functionality to actually manipulate the VarNames is in AbstractPPL - new_vn = AbstractPPL.prefix(vn, VarName{Prefix}()) + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) # Then pass to the child context return myprefix(childcontext(ctx), new_vn) end @@ -295,11 +296,11 @@ This implementation clearly is not correct, because it applies the _inner_ `Pref The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: ```@example -function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) # Pass to the child context first new_vn = myprefix(childcontext(ctx), vn) # Then apply this context's prefix - return AbstractPPL.prefix(new_vn, VarName{Prefix}()) + return AbstractPPL.prefix(new_vn, ctx.vn_prefix) end myprefix(big_ctx, @varname(x)) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e9fad54b6..eb025dec8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi) # change in the future. if should_auto_prefix(right) dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context) + prefixed_submodel_context = PrefixContext(vn, dppl_model.context) new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) right = to_submodel(new_dppl_model, true) end diff --git a/src/contexts.jl b/src/contexts.jl index dbd6f9b23..8ac085663 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -237,27 +237,34 @@ function setchildcontext(parent::MiniBatchContext, child) end """ - PrefixContext{Prefix}(context) + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} Create a context that allows you to use the wrapped `context` when running the model and -adds the `Prefix` to all parameters. +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. This context is useful in nested models to ensure that the names of the parameters are unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn context::C end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} - return PrefixContext{Prefix}(child) +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) end """ @@ -265,8 +272,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}()) +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -295,14 +302,13 @@ not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you _do_ need to modify them, then you may need to use `prefix_cond_and_fixed_variables` instead. """ -function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) child_context = childcontext(ctx) # vn_prefixed contains the prefixes from all lower levels vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( child_context, vn ) - return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()), - child_context_without_prefixes + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) @@ -314,11 +320,16 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName end """ - prefix(model::Model, x) - -Return `model` but with all random variables prefixed by `x`. + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) -If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. # Examples @@ -328,17 +339,19 @@ julia> using DynamicPPL: prefix julia> @model demo() = x ~ Dirac(1) demo (generic function with 2 methods) -julia> rand(prefix(demo(), :my_prefix)) +julia> rand(prefix(demo(), @varname(my_prefix))) (var"my_prefix.x" = 1,) -julia> # One can also use `Val` to avoid runtime overheads. - rand(prefix(demo(), Val(:my_prefix))) +julia> rand(prefix(demo(), Val(:my_prefix))) (var"my_prefix.x" = 1,) ``` """ -prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) -function prefix(model::Model, ::Val{x}) where {x} - return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) end """ @@ -426,7 +439,7 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} +function hasconditioned_nested(context::PrefixContext, vn) return hasconditioned_nested(collapse_prefix_stack(context), vn) end @@ -444,7 +457,7 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} +function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) @@ -715,13 +728,13 @@ which explains this in much more detail. ```jldoctest julia> using DynamicPPL: collapse_prefix_stack -julia> c1 = PrefixContext{:a}(ConditionContext((x=1, ))); +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); julia> collapse_prefix_stack(c1) ConditionContext(Dict(a.x => 1), DefaultContext()) julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,))))); + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); julia> collapsed = collapse_prefix_stack(c2); @@ -733,14 +746,14 @@ julia> # `collapsed` really looks something like this: (1, 2) ``` """ -function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix} +function collapse_prefix_stack(context::PrefixContext) # Collapse the child context (thus applying any inner prefixes first) collapsed = collapse_prefix_stack(childcontext(context)) # Prefix any conditioned variables with the current prefix # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. # So is this function. In the worst case scenario, this is O(N^2) in the # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}()) + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end function collapse_prefix_stack(context::AbstractContext) return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) diff --git a/src/model.jl b/src/model.jl index cfe87ad44..c7c4bdf57 100644 --- a/src/model.jl +++ b/src/model.jl @@ -429,7 +429,7 @@ julia> # Nested ones also work. # (Note that `PrefixContext` also prefixes the variables of any # ConditionContext that is _inside_ it; because of this, the type of the # container has to be broadened to a `Dict`.) - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); + cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) true @@ -441,7 +441,7 @@ julia> # Since we conditioned on `a.m`, it is not treated as a random variable. a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); + cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> conditioned(cm) Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: @@ -769,7 +769,7 @@ julia> # Returns all the variables we have fixed on + their values. (x = 100.0, m = 1.0) julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); + cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) true @@ -779,7 +779,7 @@ julia> keys(VarInfo(cm)) a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); + cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> fixed(cm) Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..5f1ec95ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -223,12 +223,12 @@ end prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) function prefix_submodel_context(prefix, ctx) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) end function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) # E.g. `prefix="asd"`. - return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) end function prefix_submodel_context(prefix::Bool, ctx) diff --git a/test/contexts.jl b/test/contexts.jl index 081e59775..91ca62a12 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,14 +57,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :minibatch => MiniBatchContext(DefaultContext(), 0.0), - :prefix => PrefixContext{:x}(DefaultContext()), + :prefix => PrefixContext(@varname(x)), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(y) => 2.0))) + (x=1.0,), + PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -132,31 +133,37 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) + ctx = @inferred PrefixContext( + @varname(a), + PrefixContext( + @varname(b), + PrefixContext( + @varname(c), + PrefixContext( + @varname(d), + PrefixContext( + @varname(e), PrefixContext(@varname(f), DefaultContext()) + ), ), ), ), ) - vn = VarName{:x}() + vn = @varname(x) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) - vn = VarName{:x}(((1,),)) + vn = @varname(x[1]) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext{:b}(ctx2) + ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) @@ -164,30 +171,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "prefix_and_strip_contexts" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext{:a}(DefaultContext())) + ctx2 = SamplingContext(PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == SamplingContext() - ctx3 = PrefixContext{:a}(ConditionContext((a=1,))) + ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext{:a}(ConditionContext((a=1,)))) + ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == SamplingContext(ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix = :my_prefix - context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) + prefix_vn = @varname(my_prefix) + context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) @@ -196,7 +203,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Extract the ground truth varnames vns_expected = Set([ - AbstractPPL.prefix(vn, VarName{prefix}()) for + AbstractPPL.prefix(vn, prefix_vn) for vn in DynamicPPL.TestUtils.varnames(model) ]) @@ -374,7 +381,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end # Prefix -> Condition - c1 = PrefixContext{:a}(ConditionContext((c=1, d=2))) + c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) c1 = collapse_prefix_stack(c1) @test has_no_prefixcontexts(c1) c1_vals = conditioned(c1) @@ -383,7 +390,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c1_vals, @varname(a.d)) == 2 # Condition -> Prefix - c2 = (ConditionContext((c=1, d=2), PrefixContext{:a}(DefaultContext()))) + c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) c2 = collapse_prefix_stack(c2) @test has_no_prefixcontexts(c2) c2_vals = conditioned(c2) @@ -392,7 +399,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c2_vals, @varname(d)) == 2 # Prefix -> Fixed - c3 = PrefixContext{:a}(FixedContext((f=1, g=2))) + c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) c3 = collapse_prefix_stack(c3) c3_vals = fixed(c3) @test length(c3_vals) == 2 @@ -401,7 +408,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c3_vals, @varname(a.g)) == 2 # Fixed -> Prefix - c4 = (FixedContext((f=1, g=2), PrefixContext{:a}(DefaultContext()))) + c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) c4 = collapse_prefix_stack(c4) @test has_no_prefixcontexts(c4) c4_vals = fixed(c4) @@ -410,8 +417,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c4_vals, @varname(g)) == 2 # Prefix -> Condition -> Prefix -> Condition - c5 = PrefixContext{:a}( - ConditionContext((c=1,), PrefixContext{:b}(ConditionContext((d=2,)))) + c5 = PrefixContext( + @varname(a), + ConditionContext( + (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) + ), ) c5 = collapse_prefix_stack(c5) @test has_no_prefixcontexts(c5) @@ -421,8 +431,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c5_vals, @varname(a.b.d)) == 2 # Prefix -> Condition -> Prefix -> Fixed - c6 = PrefixContext{:a}( - ConditionContext((c=1,), PrefixContext{:b}(FixedContext((d=2,)))) + c6 = PrefixContext( + @varname(a), + ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), ) c6 = collapse_prefix_stack(c6) @test has_no_prefixcontexts(c6) diff --git a/test/submodels.jl b/test/submodels.jl index 6a8a2c889..e79eed2c3 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -122,9 +122,7 @@ using Test p.b ~ Normal() return (p.a, p.b) end - expected_vns = Set([ - @varname(var"p.a".x[1]), @varname(var"p.a".y), @varname(p.b) - ]) + expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)]) @test Set(keys(VarInfo(g()))) == expected_vns # Check that we can condition/fix on any of them from the outside