Skip to content

Commit 36c8dd7

Browse files
committed
Implement values_as_in_model using an accumulator
1 parent 299e17b commit 36c8dd7

File tree

3 files changed

+69
-72
lines changed

3 files changed

+69
-72
lines changed

src/compiler.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,21 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
402402
end
403403

404404
function generate_assign(left, right)
405-
right_expr = :($(TrackedValue)($right))
406-
tilde_expr = generate_tilde(left, right_expr)
405+
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
406+
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
407+
@gensym acc right_val vn
407408
return quote
408-
if $(is_extracting_values)(__context__)
409-
$tilde_expr
410-
else
411-
$left = $right
409+
$right_val = $right
410+
if $(DynamicPPL.is_extracting_values)(__varinfo__)
411+
$vn = $(DynamicPPL.prefix)(
412+
__context__,
413+
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))),
414+
)
415+
__varinfo__ = $(map_accumulator!!)(
416+
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
417+
)
412418
end
419+
$left = $right_val
413420
end
414421
end
415422

src/values_as_in_model.jl

Lines changed: 27 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
1-
struct TrackedValue{T}
2-
value::T
3-
end
4-
5-
is_tracked_value(::TrackedValue) = true
6-
is_tracked_value(::Any) = false
7-
8-
check_tilde_rhs(x::TrackedValue) = x
9-
101
"""
11-
ValuesAsInModelContext
2+
ValuesAsInModelAccumulator <: AbstractAccumulator
123
13-
A context that is used by [`values_as_in_model`](@ref) to obtain values
4+
An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
145
of the model parameters as they are in the model.
156
167
This is particularly useful when working in unconstrained space, but one
@@ -19,72 +10,43 @@ wants to extract the realization of a model in a constrained space.
1910
# Fields
2011
$(TYPEDFIELDS)
2112
"""
22-
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
13+
struct ValuesAsInModelAccumulator <: AbstractAccumulator
2314
"values that are extracted from the model"
2415
values::OrderedDict
2516
"whether to extract variables on the LHS of :="
2617
include_colon_eq::Bool
27-
"child context"
28-
context::C
2918
end
30-
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
19+
function ValuesAsInModelAccumulator(include_colon_eq)
20+
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
3221
end
3322

34-
NodeTrait(::ValuesAsInModelContext) = IsParent()
35-
childcontext(context::ValuesAsInModelContext) = context.context
36-
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
38-
end
23+
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel
3924

40-
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
41-
function is_extracting_values(context::AbstractContext)
42-
return is_extracting_values(NodeTrait(context), context)
25+
function split(acc::ValuesAsInModelAccumulator)
26+
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
4327
end
44-
is_extracting_values(::IsParent, ::AbstractContext) = false
45-
is_extracting_values(::IsLeaf, ::AbstractContext) = false
46-
47-
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
48-
return setindex!(context.values, copy(value), prefix(context, vn))
28+
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
29+
return ValuesAsInModelAccumulator(
30+
merge(acc1.values, acc2.values), acc1.include_colon_eq
31+
)
4932
end
5033

51-
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
52-
return push!.((context,), vns, values)
34+
function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
35+
setindex!(acc.values, deepcopy(val), vn)
36+
return acc
5337
end
5438

55-
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
56-
function broadcast_push!(
57-
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
58-
)
59-
for (vn, col) in zip(vns, eachcol(values))
60-
push!(context, vn, col)
61-
end
39+
function is_extracting_values(vi::AbstractVarInfo)
40+
return hasacc(vi, Val(:ValuesAsInModel)) &&
41+
getacc(vi, Val(:ValuesAsInModel)).include_colon_eq
6242
end
6343

64-
# `tilde_asssume`
65-
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
66-
if is_tracked_value(right)
67-
value = right.value
68-
else
69-
value, vi = tilde_assume(childcontext(context), right, vn, vi)
70-
end
71-
push!(context, vn, value)
72-
return value, vi
73-
end
74-
function tilde_assume(
75-
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
76-
)
77-
if is_tracked_value(right)
78-
value = right.value
79-
else
80-
value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
81-
end
82-
# Save the value.
83-
push!(context, vn, value)
84-
# Pass on.
85-
return value, vi
44+
function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
45+
return push!(acc, vn, val)
8646
end
8747

48+
accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc
49+
8850
"""
8951
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
9052
@@ -103,7 +65,7 @@ space at the cost of additional model evaluations.
10365
- `model::Model`: model to extract realizations from.
10466
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
10567
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
106-
- `context::AbstractContext`: base context to use for the extraction. Defaults
68+
- `context::AbstractContext`: evaluation context to use in the extraction. Defaults
10769
to `DynamicPPL.DefaultContext()`.
10870
10971
# Examples
@@ -164,7 +126,8 @@ function values_as_in_model(
164126
varinfo::AbstractVarInfo,
165127
context::AbstractContext=DefaultContext(),
166128
)
167-
context = ValuesAsInModelContext(include_colon_eq, context)
168-
evaluate!!(model, varinfo, context)
169-
return context.values
129+
accs = getaccs(varinfo)
130+
varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),))
131+
varinfo = last(evaluate!!(model, varinfo, context))
132+
return getacc(varinfo, Val(:ValuesAsInModel)).values
170133
end

test/compiler.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,10 @@ module Issue537 end
732732
y := 100 + x
733733
return (; x, y)
734734
end
735-
@model function demo_tracked_submodel()
735+
@model function demo_tracked_submodel_no_prefix()
736736
return vals ~ to_submodel(demo_tracked(), false)
737737
end
738-
for model in [demo_tracked(), demo_tracked_submodel()]
738+
for model in [demo_tracked(), demo_tracked_submodel_no_prefix()]
739739
# Make sure it's runnable and `y` is present in the return-value.
740740
@test model() isa NamedTuple{(:x, :y)}
741741

@@ -756,6 +756,33 @@ module Issue537 end
756756
@test haskey(values, @varname(x))
757757
@test !haskey(values, @varname(y))
758758
end
759+
760+
@model function demo_tracked_return_x()
761+
x ~ Normal()
762+
y := 100 + x
763+
return x
764+
end
765+
@model function demo_tracked_submodel_prefix()
766+
return a ~ to_submodel(demo_tracked_return_x())
767+
end
768+
@model function demo_tracked_subsubmodel_prefix()
769+
return b ~ to_submodel(demo_tracked_submodel_prefix())
770+
end
771+
# As above, but the variables should now have their names prefixed with `b.a`.
772+
model = demo_tracked_subsubmodel_prefix()
773+
varinfo = VarInfo(model)
774+
@test haskey(varinfo, @varname(b.a.x))
775+
@test length(keys(varinfo)) == 1
776+
777+
values = values_as_in_model(model, true, deepcopy(varinfo))
778+
@test haskey(values, @varname(b.a.x))
779+
@test haskey(values, @varname(b.a.y))
780+
781+
# And if include_colon_eq is set to `false`, then `values` should
782+
# only contain `x`.
783+
values = values_as_in_model(model, false, deepcopy(varinfo))
784+
@test haskey(values, @varname(b.a.x))
785+
@test length(keys(varinfo)) == 1
759786
end
760787

761788
@testset "signature parsing + TypeWrap" begin

0 commit comments

Comments
 (0)