Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function varname_in_chain!(
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
for vn in AbstractPPL.varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getoptic(vn)
varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out)
Expand All @@ -107,7 +107,7 @@ function values_from_chain(
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
out = similar(x)
for vn in varname_leaves(VarName{sym}(), x)
for vn in AbstractPPL.varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getoptic(vn)
out = Accessors.set(
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Bijectors: Bijectors
using Accessors: Accessors

# For backwards compat.
using DynamicPPL: varname_leaves, update_values!!
using DynamicPPL: update_values!!

include("test_utils/model_interface.jl")
include("test_utils/models.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function test_sampler(
for vn in filter(varnames_filter, varnames(model))
# We want to compare elementwise which can be achieved by
# extracting the leaves of the `VarName` and the corresponding value.
for vn_leaf in varname_leaves(vn, get(target_values, vn))
for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn))
target_value = get(target_values, vn_leaf)
chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
@test chain_mean_value ≈ target_value atol = atol rtol = rtol
Expand Down
2 changes: 1 addition & 1 deletion test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
# here to split up arrays which could potentially have some,
# but not all, elements being `missing`.
conditioned_vns = mapreduce(
p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second),
p -> AbstractPPL.varname_leaves(p.first, p.second),
vcat,
pairs(conditioned_values),
)
Expand Down
2 changes: 1 addition & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
chain_sym_map = Dict{Symbol,Symbol}()
for vn_parent in keys(var_info)
sym = DynamicPPL.getsym(vn_parent)
vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent])
vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent])
for vn_child in vn_children
chain_sym_map[Symbol(vn_child)] = sym
end
Expand Down
6 changes: 3 additions & 3 deletions test/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
chain = make_chain_from_prior(model, 10)
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
for vn in keys(d)
val = DynamicPPL.getvalue(d, vn)
val = AbstractPPL.getvalue(d, vn)
# Because value_iterator_from_chain groups varnames with
# the same parent symbol, we have to ungroup them here
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
for vn_leaf in AbstractPPL.varname_leaves(vn, val)
val_leaf = AbstractPPL.getvalue(d, vn_leaf)
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
end
end
Expand Down
Loading