diff --git a/src/compiler.jl b/src/compiler.jl index 9eb4835d3..b783c2a13 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,6 +29,18 @@ function need_concretize(expr) end end +""" + make_varname_expression(expr) + +Return a `VarName` based on `expr`, concretizing it if necessary. +""" +function make_varname_expression(expr) + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. + return AbstractPPL.drop_escape(varname(expr, need_concretize(expr))) +end + """ isassumption(expr[, vn]) @@ -48,10 +60,7 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption( - expr::Union{Expr,Symbol}, - vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), -) +function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( __context__, $(DynamicPPL.prefix)(__context__, $vn) @@ -402,14 +411,18 @@ function generate_mainbody!(mod, found, expr::Expr, warn) end function generate_assign(left, right) - right_expr = :($(TrackedValue)($right)) - tilde_expr = generate_tilde(left, right_expr) + # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for + # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. + @gensym acc right_val vn return quote - if $(is_extracting_values)(__context__) - $tilde_expr - else - $left = $right + $right_val = $right + if $(DynamicPPL.is_extracting_values)(__varinfo__) + $vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left))) + __varinfo__ = $(map_accumulator!!)( + $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + ) end + $left = $right_val end end @@ -437,14 +450,9 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption value dist - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist - ) + $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 3ec474940..4d6225c10 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -1,16 +1,7 @@ -struct TrackedValue{T} - value::T -end - -is_tracked_value(::TrackedValue) = true -is_tracked_value(::Any) = false - -check_tilde_rhs(x::TrackedValue) = x - """ - ValuesAsInModelContext + ValuesAsInModelAccumulator <: AbstractAccumulator -A context that is used by [`values_as_in_model`](@ref) to obtain values +An accumulator that is used by [`values_as_in_model`](@ref) to obtain values of the model parameters as they are in the model. This is particularly useful when working in unconstrained space, but one @@ -19,72 +10,47 @@ wants to extract the realization of a model in a constrained space. # Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext +struct ValuesAsInModelAccumulator <: AbstractAccumulator "values that are extracted from the model" values::OrderedDict "whether to extract variables on the LHS of :=" include_colon_eq::Bool - "child context" - context::C end -function ValuesAsInModelContext(include_colon_eq, context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context) +function ValuesAsInModelAccumulator(include_colon_eq) + return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) end -NodeTrait(::ValuesAsInModelContext) = IsParent() -childcontext(context::ValuesAsInModelContext) = context.context -function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, context.include_colon_eq, child) -end +accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel -is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq -function is_extracting_values(context::AbstractContext) - return is_extracting_values(NodeTrait(context), context) +function split(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq) end -is_extracting_values(::IsParent, ::AbstractContext) = false -is_extracting_values(::IsLeaf, ::AbstractContext) = false - -function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), prefix(context, vn)) +function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) + if acc1.include_colon_eq != acc2.include_colon_eq + msg = "Cannot combine accumulators with different include_colon_eq values." + throw(ArgumentError(msg)) + end + return ValuesAsInModelAccumulator( + merge(acc1.values, acc2.values), acc1.include_colon_eq + ) end -function broadcast_push!(context::ValuesAsInModelContext, vns, values) - return push!.((context,), vns, values) +function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val) + setindex!(acc.values, deepcopy(val), vn) + return acc end -# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. -function broadcast_push!( - context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix -) - for (vn, col) in zip(vns, eachcol(values)) - push!(context, vn, col) - end +function is_extracting_values(vi::AbstractVarInfo) + return hasacc(vi, Val(:ValuesAsInModel)) && + getacc(vi, Val(:ValuesAsInModel)).include_colon_eq end -# `tilde_asssume` -function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - if is_tracked_value(right) - value = right.value - else - value, vi = tilde_assume(childcontext(context), right, vn, vi) - end - push!(context, vn, value) - return value, vi -end -function tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi -) - if is_tracked_value(right) - value = right.value - else - value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - end - # Save the value. - push!(context, vn, value) - # Pass on. - return value, vi +function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) + return push!(acc, vn, val) end +accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc + """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) @@ -103,7 +69,7 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: base context to use for the extraction. Defaults +- `context::AbstractContext`: evaluation context to use in the extraction. Defaults to `DynamicPPL.DefaultContext()`. # Examples @@ -164,7 +130,8 @@ function values_as_in_model( varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext(), ) - context = ValuesAsInModelContext(include_colon_eq, context) - evaluate!!(model, varinfo, context) - return context.values + accs = getaccs(varinfo) + varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) + varinfo = last(evaluate!!(model, varinfo, context)) + return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/test/compiler.jl b/test/compiler.jl index 81c018111..2e76de27f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -732,10 +732,10 @@ module Issue537 end y := 100 + x return (; x, y) end - @model function demo_tracked_submodel() + @model function demo_tracked_submodel_no_prefix() return vals ~ to_submodel(demo_tracked(), false) end - for model in [demo_tracked(), demo_tracked_submodel()] + for model in [demo_tracked(), demo_tracked_submodel_no_prefix()] # Make sure it's runnable and `y` is present in the return-value. @test model() isa NamedTuple{(:x, :y)} @@ -756,6 +756,33 @@ module Issue537 end @test haskey(values, @varname(x)) @test !haskey(values, @varname(y)) end + + @model function demo_tracked_return_x() + x ~ Normal() + y := 100 + x + return x + end + @model function demo_tracked_submodel_prefix() + return a ~ to_submodel(demo_tracked_return_x()) + end + @model function demo_tracked_subsubmodel_prefix() + return b ~ to_submodel(demo_tracked_submodel_prefix()) + end + # As above, but the variables should now have their names prefixed with `b.a`. + model = demo_tracked_subsubmodel_prefix() + varinfo = VarInfo(model) + @test haskey(varinfo, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 + + values = values_as_in_model(model, true, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test haskey(values, @varname(b.a.y)) + + # And if include_colon_eq is set to `false`, then `values` should + # only contain `x`. + values = values_as_in_model(model, false, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 end @testset "signature parsing + TypeWrap" begin diff --git a/test/contexts.jl b/test/contexts.jl index 5f22b75eb..1dd6a2280 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -154,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + ctx4 = DynamicPPL.SamplingContext(ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end