diff --git a/src/model_utils.jl b/src/model_utils.jl index ac4ec7022..e4c326b39 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -81,7 +81,7 @@ function varname_in_chain!( # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) @@ -107,7 +107,7 @@ function values_from_chain( # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) out = Accessors.set( diff --git a/src/test_utils.jl b/src/test_utils.jl index 65079f023..195345d60 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,7 @@ using Bijectors: Bijectors using Accessors: Accessors # For backwards compat. -using DynamicPPL: varname_leaves, update_values!! +using DynamicPPL: update_values!! include("test_utils/model_interface.jl") include("test_utils/models.jl") diff --git a/src/test_utils/sampler.jl b/src/test_utils/sampler.jl index 71cdb1cac..3ef965bad 100644 --- a/src/test_utils/sampler.jl +++ b/src/test_utils/sampler.jl @@ -51,7 +51,7 @@ function test_sampler( for vn in filter(varnames_filter, varnames(model)) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varname_leaves(vn, get(target_values, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol diff --git a/test/contexts.jl b/test/contexts.jl index 365865e7e..107607d99 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -93,7 +93,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # here to split up arrays which could potentially have some, # but not all, elements being `missing`. conditioned_vns = mapreduce( - p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + p -> AbstractPPL.varname_leaves(p.first, p.second), vcat, pairs(conditioned_values), ) diff --git a/test/model.jl b/test/model.jl index f062a70b4..2234dde8f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) + vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym end diff --git a/test/model_utils.jl b/test/model_utils.jl index 720ae55aa..af695dbf2 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -6,11 +6,11 @@ chain = make_chain_from_prior(model, 10) for (i, d) in enumerate(value_iterator_from_chain(model, chain)) for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) + val = AbstractPPL.getvalue(d, vn) # Because value_iterator_from_chain groups varnames with # the same parent symbol, we have to ungroup them here - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) + for vn_leaf in AbstractPPL.varname_leaves(vn, val) + val_leaf = AbstractPPL.getvalue(d, vn_leaf) @test val_leaf == chain[i, Symbol(vn_leaf), 1] end end