Skip to content

Commit fead2a2

Browse files
authored
tidy occurrences of varname_leaves as well (#1031)
1 parent c8e5841 commit fead2a2

File tree

6 files changed

+9
-9
lines changed

6 files changed

+9
-9
lines changed

src/model_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function varname_in_chain!(
8181
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
8282
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
8383
# to extract the value from the `chain`.
84-
for vn in varname_leaves(VarName{sym}(), x)
84+
for vn in AbstractPPL.varname_leaves(VarName{sym}(), x)
8585
# Update `out`, possibly in place, and return.
8686
l = AbstractPPL.getoptic(vn)
8787
varname_in_chain!(x, l vn_parent, chain, chain_idx, iteration_idx, out)
@@ -107,7 +107,7 @@ function values_from_chain(
107107
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
108108
# to extract the value from the `chain`.
109109
out = similar(x)
110-
for vn in varname_leaves(VarName{sym}(), x)
110+
for vn in AbstractPPL.varname_leaves(VarName{sym}(), x)
111111
# Update `out`, possibly in place, and return.
112112
l = AbstractPPL.getoptic(vn)
113113
out = Accessors.set(

src/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Bijectors: Bijectors
1111
using Accessors: Accessors
1212

1313
# For backwards compat.
14-
using DynamicPPL: varname_leaves, update_values!!
14+
using DynamicPPL: update_values!!
1515

1616
include("test_utils/model_interface.jl")
1717
include("test_utils/models.jl")

src/test_utils/sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function test_sampler(
5151
for vn in filter(varnames_filter, varnames(model))
5252
# We want to compare elementwise which can be achieved by
5353
# extracting the leaves of the `VarName` and the corresponding value.
54-
for vn_leaf in varname_leaves(vn, get(target_values, vn))
54+
for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn))
5555
target_value = get(target_values, vn_leaf)
5656
chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
5757
@test chain_mean_value target_value atol = atol rtol = rtol

test/contexts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
9393
# here to split up arrays which could potentially have some,
9494
# but not all, elements being `missing`.
9595
conditioned_vns = mapreduce(
96-
p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second),
96+
p -> AbstractPPL.varname_leaves(p.first, p.second),
9797
vcat,
9898
pairs(conditioned_values),
9999
)

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
7171
chain_sym_map = Dict{Symbol,Symbol}()
7272
for vn_parent in keys(var_info)
7373
sym = DynamicPPL.getsym(vn_parent)
74-
vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent])
74+
vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent])
7575
for vn_child in vn_children
7676
chain_sym_map[Symbol(vn_child)] = sym
7777
end

test/model_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
chain = make_chain_from_prior(model, 10)
77
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
88
for vn in keys(d)
9-
val = DynamicPPL.getvalue(d, vn)
9+
val = AbstractPPL.getvalue(d, vn)
1010
# Because value_iterator_from_chain groups varnames with
1111
# the same parent symbol, we have to ungroup them here
12-
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
13-
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
12+
for vn_leaf in AbstractPPL.varname_leaves(vn, val)
13+
val_leaf = AbstractPPL.getvalue(d, vn_leaf)
1414
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
1515
end
1616
end

0 commit comments

Comments
 (0)