Skip to content

Commit 866eb6f

Browse files
Replacing tonamedtuple (#526)
* added impl of varname_and_value_leaves * added examples with cholesky to varname_and_value_leaves doctests * added more descriptive docstring of iterate for Leaf * added concrete example in comment of iterate for Leaf * added small docstring to Leaf * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e2178c6 commit 866eb6f

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
255255
```@docs
256256
DynamicPPL.unflatten
257257
DynamicPPL.tonamedtuple
258+
DynamicPPL.varname_leaves
259+
DynamicPPL.varname_and_value_leaves
258260
```
259261

260262
#### `SimpleVarInfo`

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Setfield: Setfield
1515
using ZygoteRules: ZygoteRules
1616
using LogDensityProblems: LogDensityProblems
1717

18-
using LinearAlgebra: Cholesky
18+
using LinearAlgebra: LinearAlgebra, Cholesky
1919

2020
using DocStringExtensions
2121

src/utils.jl

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,3 +870,155 @@ function varname_leaves(vn::VarName, val::NamedTuple)
870870
end
871871
return Iterators.flatten(iter)
872872
end
873+
874+
"""
875+
varname_and_value_leaves(vn::VarName, val)
876+
877+
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
878+
879+
# Examples
880+
```jldoctest varname-and-value-leaves
881+
julia> using DynamicPPL: varname_and_value_leaves
882+
883+
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
884+
(x[1], 1)
885+
(x[2], 2)
886+
887+
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
888+
(x[1:2][1], 1)
889+
(x[1:2][2], 2)
890+
891+
julia> x = (y = 1, z = [[2.0], [3.0]]);
892+
893+
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
894+
(x.y, 1)
895+
(x.z[1][1], 2.0)
896+
(x.z[2][1], 3.0)
897+
```
898+
899+
There are also some special handling for certain types:
900+
901+
```jldoctest varname-and-value-leaves
902+
julia> using LinearAlgebra
903+
904+
julia> x = reshape(1:4, 2, 2);
905+
906+
julia> # `LowerTriangular`
907+
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
908+
(x[1,1], 1)
909+
(x[2,1], 2)
910+
(x[2,2], 4)
911+
912+
julia> # `UpperTriangular`
913+
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
914+
(x[1,1], 1)
915+
(x[1,2], 3)
916+
(x[2,2], 4)
917+
918+
julia> # `Cholesky` with lower-triangular
919+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
920+
(x[1,1], 1.0)
921+
(x[2,1], 0.0)
922+
(x[2,2], 1.0)
923+
924+
julia> # `Cholesky` with upper-triangular
925+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
926+
(x[1,1], 1.0)
927+
(x[1,2], 0.0)
928+
(x[2,2], 1.0)
929+
```
930+
"""
931+
function varname_and_value_leaves(vn::VarName, x)
932+
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
933+
end
934+
935+
"""
936+
Leaf{T}
937+
938+
A container that represents the leaf of a nested structure, implementing
939+
`iterate` to return itself.
940+
941+
This is particularly useful in conjunction with `Iterators.flatten` to
942+
prevent flattening of nested structures.
943+
"""
944+
struct Leaf{T}
945+
value::T
946+
end
947+
948+
Leaf(xs...) = Leaf(xs)
949+
950+
# Allow us to treat `Leaf` as an iterator containing a single element.
951+
# Something like an `[x]` would also be an iterator with a single element,
952+
# but when we call `flatten` on this, it would also iterate over `x`,
953+
# unflattening that too. By making `Leaf` a single-element iterator, which
954+
# returns itself, we can call `iterate` on this as many times as we like
955+
# without causing any change. The result is that `Iterators.flatten`
956+
# will _not_ unflatten `Leaf`s.
957+
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
958+
#
959+
# julia> iterate(1)
960+
# (1, nothing)
961+
#
962+
# One immediate example where this becomes in our scenario is that we might
963+
# have `missing` values in our data, which does _not_ have an `iterate`
964+
# implemented. Calling `Iterators.flatten` on this would cause an error.
965+
Base.iterate(leaf::Leaf) = leaf, nothing
966+
Base.iterate(::Leaf, _) = nothing
967+
968+
# Convenience.
969+
value(leaf::Leaf) = leaf.value
970+
971+
# Leaf-types.
972+
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
973+
function varname_and_value_leaves_inner(
974+
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
975+
)
976+
return (
977+
Leaf(
978+
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
979+
val[I],
980+
) for I in CartesianIndices(val)
981+
)
982+
end
983+
# Containers.
984+
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
985+
return Iterators.flatten(
986+
varname_and_value_leaves_inner(
987+
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
988+
val[I],
989+
) for I in CartesianIndices(val)
990+
)
991+
end
992+
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
993+
iter = Iterators.map(keys(val)) do sym
994+
lens = DynamicPPL.Setfield.PropertyLens{sym}()
995+
varname_and_value_leaves_inner(vn lens, get(val, lens))
996+
end
997+
998+
return Iterators.flatten(iter)
999+
end
1000+
# Special types.
1001+
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
1002+
# TODO: Or do we use `PDMat` here?
1003+
return varname_and_value_leaves_inner(vn, x.UL)
1004+
end
1005+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
1006+
return (
1007+
Leaf(
1008+
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
1009+
x[I],
1010+
)
1011+
# Iteration over the lower-triangular indices.
1012+
for I in CartesianIndices(x) if I[1] >= I[2]
1013+
)
1014+
end
1015+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
1016+
return (
1017+
Leaf(
1018+
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
1019+
x[I],
1020+
)
1021+
# Iteration over the upper-triangular indices.
1022+
for I in CartesianIndices(x) if I[1] <= I[2]
1023+
)
1024+
end

0 commit comments

Comments
 (0)