Skip to content

Commit 04b03cd

Browse files
sunxd3yebai
andauthored
Remove tonamedtuple (#547)
* Remove dependencies to `tonamedtuple` * Remove `tonamedtuple`s * Minor version bump --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 2e8adf4 commit 04b03cd

File tree

8 files changed

+16
-89
lines changed

8 files changed

+16
-89
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.21"
3+
version = "0.24.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ DynamicPPL.reconstruct
258258
Base.merge(::AbstractVarInfo)
259259
DynamicPPL.subset
260260
DynamicPPL.unflatten
261-
DynamicPPL.tonamedtuple
262261
DynamicPPL.varname_leaves
263262
DynamicPPL.varname_and_value_leaves
264263
```

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ export AbstractVarInfo,
7171
invlink,
7272
invlink!,
7373
invlink!!,
74-
tonamedtuple,
7574
values_as,
7675
# VarName (reexport from AbstractPPL)
7776
VarName,

src/abstract_varinfo.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -738,21 +738,6 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac
738738
return unflatten(varinfo, sampler, θ)
739739
end
740740

741-
"""
742-
tonamedtuple(vi::AbstractVarInfo)
743-
744-
Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and
745-
indexing string of the variable.
746-
747-
For example, a model that had a vector of vector-valued
748-
variables `x` would return
749-
750-
```julia
751-
(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), )
752-
```
753-
"""
754-
function tonamedtuple end
755-
756741
# TODO: Clean up all this linking stuff once and for all!
757742
"""
758743
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)

src/simple_varinfo.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -532,44 +532,6 @@ function dot_assume(
532532
return value, lp, vi
533533
end
534534

535-
# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl.
536-
# TODO: Move away from using these `tonamedtuple` methods.
537-
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
538-
nt_vals = map(keys(vi)) do vn
539-
val = vi[vn]
540-
vns = collect(TestUtils.varname_leaves(vn, val))
541-
vals = map(copy Base.Fix1(getindex, vi), vns)
542-
(vals, map(string, vns))
543-
end
544-
545-
return NamedTuple{names}(nt_vals)
546-
end
547-
548-
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
549-
syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}()
550-
for vn in keys(vi)
551-
# Extract the leaf varnames and values.
552-
val = vi[vn]
553-
vns = collect(TestUtils.varname_leaves(vn, val))
554-
vals = map(copy Base.Fix1(getindex, vi), vns)
555-
556-
# Determine the corresponding symbol.
557-
sym = only(unique(map(getsym, vns)))
558-
559-
# Initialize entry if not yet initialized.
560-
if !haskey(syms_to_result, sym)
561-
syms_to_result[sym] = (Real[], String[])
562-
end
563-
564-
# Combine with old result.
565-
old_vals, old_string_vns = syms_to_result[sym]
566-
syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns)))
567-
end
568-
569-
# Construct `NamedTuple`.
570-
return NamedTuple(pairs(syms_to_result))
571-
end
572-
573535
# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
574536
function settrans!!(vi::SimpleVarInfo, trans)
575537
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())

src/threadsafe.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,6 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
209209
return is_flagged(vi.varinfo, vn, flag)
210210
end
211211

212-
tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo)
213-
214212
# Transformations.
215213
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
216214
return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)

src/varinfo.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,22 +1506,6 @@ end
15061506
return expr
15071507
end
15081508

1509-
# TODO: Remove this completely.
1510-
tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo)
1511-
function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names}
1512-
length(names) === 0 && return NamedTuple()
1513-
1514-
vals_tuple = map(values(metadata)) do x
1515-
# NOTE: `tonamedtuple` is really only used in Turing.jl to convert to
1516-
# a "transition". This means that we really don't mutations of the values
1517-
# in `varinfo` to propoagate the previous samples. Hence we `copy.`
1518-
vals = map(copy Base.Fix1(getindex, varinfo), x.vns)
1519-
return vals, map(string, x.vns)
1520-
end
1521-
1522-
return NamedTuple{names}(vals_tuple)
1523-
end
1524-
15251509
@inline function findvns(vi, f_vns)
15261510
if length(f_vns) == 0
15271511
throw("Unidentified error, please report this error in an issue.")

test/test_util.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,22 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
5858
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
5959
θ_new = var_info[spl]
6060
@test θ_old != θ_new
61-
nt = DynamicPPL.tonamedtuple(var_info)
62-
for (k, (vals, names)) in pairs(nt)
63-
for (n, v) in zip(names, vals)
64-
if Symbol(n) keys(chain)
65-
# Assume it's a group
66-
chain_val = vec(
67-
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
68-
)
69-
v_true = vec(v)
70-
else
71-
chain_val = chain[sample_idx, n, chain_idx]
72-
v_true = v
73-
end
74-
75-
@test v_true == chain_val
61+
vals = DynamicPPL.values_as(var_info, OrderedDict)
62+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
63+
for (n, v) in mapreduce(collect, vcat, iters)
64+
n = string(n)
65+
if Symbol(n) keys(chain)
66+
# Assume it's a group
67+
chain_val = vec(
68+
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
69+
)
70+
v_true = vec(v)
71+
else
72+
chain_val = chain[sample_idx, n, chain_idx]
73+
v_true = v
7674
end
75+
76+
@test v_true == chain_val
7777
end
7878
end
7979

0 commit comments

Comments
 (0)