Skip to content

Commit d4ef1f2

Browse files
authored
Implement values_as_in_model using an accumulator (#908)
* Implement values_as_in_model using an accumulator * Make make_varname_expression a function * Refuse to combine ValuesAsInModelAccumulators with different include_colon_eqs * Fix nested context test
1 parent 326d7ed commit d4ef1f2

File tree

4 files changed

+85
-83
lines changed

4 files changed

+85
-83
lines changed

src/compiler.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ function need_concretize(expr)
2929
end
3030
end
3131

32+
"""
33+
make_varname_expression(expr)
34+
35+
Return a `VarName` based on `expr`, concretizing it if necessary.
36+
"""
37+
function make_varname_expression(expr)
38+
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
39+
# that in DynamicPPL we the entire function body. Instead we should be
40+
# more selective with our escape. Until that's the case, we remove them all.
41+
return AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))
42+
end
43+
3244
"""
3345
isassumption(expr[, vn])
3446
@@ -48,10 +60,7 @@ evaluates to a `VarName`, and this will be used in the subsequent checks.
4860
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
4961
used in its place.
5062
"""
51-
function isassumption(
52-
expr::Union{Expr,Symbol},
53-
vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))),
54-
)
63+
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
5564
return quote
5665
if $(DynamicPPL.contextual_isassumption)(
5766
__context__, $(DynamicPPL.prefix)(__context__, $vn)
@@ -402,14 +411,18 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
402411
end
403412

404413
function generate_assign(left, right)
405-
right_expr = :($(TrackedValue)($right))
406-
tilde_expr = generate_tilde(left, right_expr)
414+
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
415+
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
416+
@gensym acc right_val vn
407417
return quote
408-
if $(is_extracting_values)(__context__)
409-
$tilde_expr
410-
else
411-
$left = $right
418+
$right_val = $right
419+
if $(DynamicPPL.is_extracting_values)(__varinfo__)
420+
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
421+
__varinfo__ = $(map_accumulator!!)(
422+
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
423+
)
412424
end
425+
$left = $right_val
413426
end
414427
end
415428

@@ -437,14 +450,9 @@ function generate_tilde(left, right)
437450
# if the LHS represents an observation
438451
@gensym vn isassumption value dist
439452

440-
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
441-
# that in DynamicPPL we the entire function body. Instead we should be
442-
# more selective with our escape. Until that's the case, we remove them all.
443453
return quote
444454
$dist = $right
445-
$vn = $(DynamicPPL.resolve_varnames)(
446-
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
447-
)
455+
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
448456
$isassumption = $(DynamicPPL.isassumption(left, vn))
449457
if $(DynamicPPL.isfixed(left, vn))
450458
$left = $(DynamicPPL.getfixed_nested)(

src/values_as_in_model.jl

Lines changed: 31 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
1-
struct TrackedValue{T}
2-
value::T
3-
end
4-
5-
is_tracked_value(::TrackedValue) = true
6-
is_tracked_value(::Any) = false
7-
8-
check_tilde_rhs(x::TrackedValue) = x
9-
101
"""
11-
ValuesAsInModelContext
2+
ValuesAsInModelAccumulator <: AbstractAccumulator
123
13-
A context that is used by [`values_as_in_model`](@ref) to obtain values
4+
An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
145
of the model parameters as they are in the model.
156
167
This is particularly useful when working in unconstrained space, but one
@@ -19,72 +10,47 @@ wants to extract the realization of a model in a constrained space.
1910
# Fields
2011
$(TYPEDFIELDS)
2112
"""
22-
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
13+
struct ValuesAsInModelAccumulator <: AbstractAccumulator
2314
"values that are extracted from the model"
2415
values::OrderedDict
2516
"whether to extract variables on the LHS of :="
2617
include_colon_eq::Bool
27-
"child context"
28-
context::C
2918
end
30-
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
19+
function ValuesAsInModelAccumulator(include_colon_eq)
20+
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
3221
end
3322

34-
NodeTrait(::ValuesAsInModelContext) = IsParent()
35-
childcontext(context::ValuesAsInModelContext) = context.context
36-
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
38-
end
23+
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel
3924

40-
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
41-
function is_extracting_values(context::AbstractContext)
42-
return is_extracting_values(NodeTrait(context), context)
25+
function split(acc::ValuesAsInModelAccumulator)
26+
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
4327
end
44-
is_extracting_values(::IsParent, ::AbstractContext) = false
45-
is_extracting_values(::IsLeaf, ::AbstractContext) = false
46-
47-
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
48-
return setindex!(context.values, copy(value), prefix(context, vn))
28+
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
29+
if acc1.include_colon_eq != acc2.include_colon_eq
30+
msg = "Cannot combine accumulators with different include_colon_eq values."
31+
throw(ArgumentError(msg))
32+
end
33+
return ValuesAsInModelAccumulator(
34+
merge(acc1.values, acc2.values), acc1.include_colon_eq
35+
)
4936
end
5037

51-
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
52-
return push!.((context,), vns, values)
38+
function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
39+
setindex!(acc.values, deepcopy(val), vn)
40+
return acc
5341
end
5442

55-
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
56-
function broadcast_push!(
57-
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
58-
)
59-
for (vn, col) in zip(vns, eachcol(values))
60-
push!(context, vn, col)
61-
end
43+
function is_extracting_values(vi::AbstractVarInfo)
44+
return hasacc(vi, Val(:ValuesAsInModel)) &&
45+
getacc(vi, Val(:ValuesAsInModel)).include_colon_eq
6246
end
6347

64-
# `tilde_asssume`
65-
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
66-
if is_tracked_value(right)
67-
value = right.value
68-
else
69-
value, vi = tilde_assume(childcontext(context), right, vn, vi)
70-
end
71-
push!(context, vn, value)
72-
return value, vi
73-
end
74-
function tilde_assume(
75-
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
76-
)
77-
if is_tracked_value(right)
78-
value = right.value
79-
else
80-
value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
81-
end
82-
# Save the value.
83-
push!(context, vn, value)
84-
# Pass on.
85-
return value, vi
48+
function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
49+
return push!(acc, vn, val)
8650
end
8751

52+
accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc
53+
8854
"""
8955
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
9056
@@ -103,7 +69,7 @@ space at the cost of additional model evaluations.
10369
- `model::Model`: model to extract realizations from.
10470
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
10571
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
106-
- `context::AbstractContext`: base context to use for the extraction. Defaults
72+
- `context::AbstractContext`: evaluation context to use in the extraction. Defaults
10773
to `DynamicPPL.DefaultContext()`.
10874
10975
# Examples
@@ -164,7 +130,8 @@ function values_as_in_model(
164130
varinfo::AbstractVarInfo,
165131
context::AbstractContext=DefaultContext(),
166132
)
167-
context = ValuesAsInModelContext(include_colon_eq, context)
168-
evaluate!!(model, varinfo, context)
169-
return context.values
133+
accs = getaccs(varinfo)
134+
varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),))
135+
varinfo = last(evaluate!!(model, varinfo, context))
136+
return getacc(varinfo, Val(:ValuesAsInModel)).values
170137
end

test/compiler.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,10 @@ module Issue537 end
732732
y := 100 + x
733733
return (; x, y)
734734
end
735-
@model function demo_tracked_submodel()
735+
@model function demo_tracked_submodel_no_prefix()
736736
return vals ~ to_submodel(demo_tracked(), false)
737737
end
738-
for model in [demo_tracked(), demo_tracked_submodel()]
738+
for model in [demo_tracked(), demo_tracked_submodel_no_prefix()]
739739
# Make sure it's runnable and `y` is present in the return-value.
740740
@test model() isa NamedTuple{(:x, :y)}
741741

@@ -756,6 +756,33 @@ module Issue537 end
756756
@test haskey(values, @varname(x))
757757
@test !haskey(values, @varname(y))
758758
end
759+
760+
@model function demo_tracked_return_x()
761+
x ~ Normal()
762+
y := 100 + x
763+
return x
764+
end
765+
@model function demo_tracked_submodel_prefix()
766+
return a ~ to_submodel(demo_tracked_return_x())
767+
end
768+
@model function demo_tracked_subsubmodel_prefix()
769+
return b ~ to_submodel(demo_tracked_submodel_prefix())
770+
end
771+
# As above, but the variables should now have their names prefixed with `b.a`.
772+
model = demo_tracked_subsubmodel_prefix()
773+
varinfo = VarInfo(model)
774+
@test haskey(varinfo, @varname(b.a.x))
775+
@test length(keys(varinfo)) == 1
776+
777+
values = values_as_in_model(model, true, deepcopy(varinfo))
778+
@test haskey(values, @varname(b.a.x))
779+
@test haskey(values, @varname(b.a.y))
780+
781+
# And if include_colon_eq is set to `false`, then `values` should
782+
# only contain `x`.
783+
values = values_as_in_model(model, false, deepcopy(varinfo))
784+
@test haskey(values, @varname(b.a.x))
785+
@test length(keys(varinfo)) == 1
759786
end
760787

761788
@testset "signature parsing + TypeWrap" begin

test/contexts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
154154
@test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1])
155155
ctx3 = PrefixContext(@varname(b), ctx2)
156156
@test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1])
157-
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
157+
ctx4 = DynamicPPL.SamplingContext(ctx3)
158158
@test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1])
159159
end
160160

0 commit comments

Comments
 (0)