Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,21 @@
end

function generate_assign(left, right)
right_expr = :($(TrackedValue)($right))
tilde_expr = generate_tilde(left, right_expr)
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
@gensym acc right_val vn

Check warning on line 407 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L407

Added line #L407 was not covered by tests
return quote
if $(is_extracting_values)(__context__)
$tilde_expr
else
$left = $right
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(

Check warning on line 411 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L409-L411

Added lines #L409 - L411 were not covered by tests
__context__,
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))),
)
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)

Check warning on line 416 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L415-L416

Added lines #L415 - L416 were not covered by tests
)
end
$left = $right_val

Check warning on line 419 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L419

Added line #L419 was not covered by tests
end
end

Expand Down
91 changes: 27 additions & 64 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
struct TrackedValue{T}
value::T
end

is_tracked_value(::TrackedValue) = true
is_tracked_value(::Any) = false

check_tilde_rhs(x::TrackedValue) = x

"""
ValuesAsInModelContext
ValuesAsInModelAccumulator <: AbstractAccumulator
A context that is used by [`values_as_in_model`](@ref) to obtain values
An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
of the model parameters as they are in the model.
This is particularly useful when working in unconstrained space, but one
Expand All @@ -19,72 +10,43 @@
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
struct ValuesAsInModelAccumulator <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
context::C
end
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
end
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
function is_extracting_values(context::AbstractContext)
return is_extracting_values(NodeTrait(context), context)
function split(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
is_extracting_values(::IsParent, ::AbstractContext) = false
is_extracting_values(::IsLeaf, ::AbstractContext) = false

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), prefix(context, vn))
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(
merge(acc1.values, acc2.values), acc1.include_colon_eq
)
end

function broadcast_push!(context::ValuesAsInModelContext, vns, values)
return push!.((context,), vns, values)
function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
setindex!(acc.values, deepcopy(val), vn)
return acc
end

# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
function broadcast_push!(
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
)
for (vn, col) in zip(vns, eachcol(values))
push!(context, vn, col)
end
function is_extracting_values(vi::AbstractVarInfo)
return hasacc(vi, Val(:ValuesAsInModel)) &&

Check warning on line 40 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L39-L40

Added lines #L39 - L40 were not covered by tests
getacc(vi, Val(:ValuesAsInModel)).include_colon_eq
end

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
if is_tracked_value(right)
value = right.value
else
value, vi = tilde_assume(childcontext(context), right, vn, vi)
end
push!(context, vn, value)
return value, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
if is_tracked_value(right)
value = right.value
else
value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Pass on.
return value, vi
function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
return push!(acc, vn, val)
end

accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc

Check warning on line 48 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L48

Added line #L48 was not covered by tests

"""
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
Expand All @@ -103,7 +65,7 @@
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: base context to use for the extraction. Defaults
- `context::AbstractContext`: evaluation context to use in the extraction. Defaults
to `DynamicPPL.DefaultContext()`.
# Examples
Expand Down Expand Up @@ -164,7 +126,8 @@
varinfo::AbstractVarInfo,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
accs = getaccs(varinfo)
varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),))
varinfo = last(evaluate!!(model, varinfo, context))
return getacc(varinfo, Val(:ValuesAsInModel)).values
end
31 changes: 29 additions & 2 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,10 +732,10 @@ module Issue537 end
y := 100 + x
return (; x, y)
end
@model function demo_tracked_submodel()
@model function demo_tracked_submodel_no_prefix()
return vals ~ to_submodel(demo_tracked(), false)
end
for model in [demo_tracked(), demo_tracked_submodel()]
for model in [demo_tracked(), demo_tracked_submodel_no_prefix()]
# Make sure it's runnable and `y` is present in the return-value.
@test model() isa NamedTuple{(:x, :y)}

Expand All @@ -756,6 +756,33 @@ module Issue537 end
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end

@model function demo_tracked_return_x()
x ~ Normal()
y := 100 + x
return x
end
@model function demo_tracked_submodel_prefix()
return a ~ to_submodel(demo_tracked_return_x())
end
@model function demo_tracked_subsubmodel_prefix()
return b ~ to_submodel(demo_tracked_submodel_prefix())
end
# As above, but the variables should now have their names prefixed with `b.a`.
model = demo_tracked_subsubmodel_prefix()
varinfo = VarInfo(model)
@test haskey(varinfo, @varname(b.a.x))
@test length(keys(varinfo)) == 1

values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(b.a.x))
@test haskey(values, @varname(b.a.y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(b.a.x))
@test length(keys(varinfo)) == 1
end

@testset "signature parsing + TypeWrap" begin
Expand Down
Loading