@@ -870,3 +870,155 @@ function varname_leaves(vn::VarName, val::NamedTuple)
870
870
end
871
871
return Iterators. flatten (iter)
872
872
end
873
+
874
+ """
875
+ varname_and_value_leaves(vn::VarName, val)
876
+
877
+ Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
878
+
879
+ # Examples
880
+ ```jldoctest varname-and-value-leaves
881
+ julia> using DynamicPPL: varname_and_value_leaves
882
+
883
+ julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
884
+ (x[1], 1)
885
+ (x[2], 2)
886
+
887
+ julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
888
+ (x[1:2][1], 1)
889
+ (x[1:2][2], 2)
890
+
891
+ julia> x = (y = 1, z = [[2.0], [3.0]]);
892
+
893
+ julia> foreach(println, varname_and_value_leaves(@varname(x), x))
894
+ (x.y, 1)
895
+ (x.z[1][1], 2.0)
896
+ (x.z[2][1], 3.0)
897
+ ```
898
+
899
+ There are also some special handling for certain types:
900
+
901
+ ```jldoctest varname-and-value-leaves
902
+ julia> using LinearAlgebra
903
+
904
+ julia> x = reshape(1:4, 2, 2);
905
+
906
+ julia> # `LowerTriangular`
907
+ foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
908
+ (x[1,1], 1)
909
+ (x[2,1], 2)
910
+ (x[2,2], 4)
911
+
912
+ julia> # `UpperTriangular`
913
+ foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
914
+ (x[1,1], 1)
915
+ (x[1,2], 3)
916
+ (x[2,2], 4)
917
+
918
+ julia> # `Cholesky` with lower-triangular
919
+ foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
920
+ (x[1,1], 1.0)
921
+ (x[2,1], 0.0)
922
+ (x[2,2], 1.0)
923
+
924
+ julia> # `Cholesky` with upper-triangular
925
+ foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
926
+ (x[1,1], 1.0)
927
+ (x[1,2], 0.0)
928
+ (x[2,2], 1.0)
929
+ ```
930
+ """
931
+ function varname_and_value_leaves (vn:: VarName , x)
932
+ return Iterators. map (value, Iterators. flatten (varname_and_value_leaves_inner (vn, x)))
933
+ end
934
+
935
+ """
936
+ Leaf{T}
937
+
938
+ A container that represents the leaf of a nested structure, implementing
939
+ `iterate` to return itself.
940
+
941
+ This is particularly useful in conjunction with `Iterators.flatten` to
942
+ prevent flattening of nested structures.
943
+ """
944
+ struct Leaf{T}
945
+ value:: T
946
+ end
947
+
948
+ Leaf (xs... ) = Leaf (xs)
949
+
950
+ # Allow us to treat `Leaf` as an iterator containing a single element.
951
+ # Something like an `[x]` would also be an iterator with a single element,
952
+ # but when we call `flatten` on this, it would also iterate over `x`,
953
+ # unflattening that too. By making `Leaf` a single-element iterator, which
954
+ # returns itself, we can call `iterate` on this as many times as we like
955
+ # without causing any change. The result is that `Iterators.flatten`
956
+ # will _not_ unflatten `Leaf`s.
957
+ # Note that this is similar to how `Base.iterate` is implemented for `Real`::
958
+ #
959
+ # julia> iterate(1)
960
+ # (1, nothing)
961
+ #
962
+ # One immediate example where this becomes in our scenario is that we might
963
+ # have `missing` values in our data, which does _not_ have an `iterate`
964
+ # implemented. Calling `Iterators.flatten` on this would cause an error.
965
+ Base. iterate (leaf:: Leaf ) = leaf, nothing
966
+ Base. iterate (:: Leaf , _) = nothing
967
+
968
+ # Convenience.
969
+ value (leaf:: Leaf ) = leaf. value
970
+
971
+ # Leaf-types.
972
+ varname_and_value_leaves_inner (vn:: VarName , x:: Real ) = [Leaf (vn, x)]
973
+ function varname_and_value_leaves_inner (
974
+ vn:: VarName , val:: AbstractArray{<:Union{Real,Missing}}
975
+ )
976
+ return (
977
+ Leaf (
978
+ VarName (vn, DynamicPPL. getlens (vn) ∘ DynamicPPL. Setfield. IndexLens (Tuple (I))),
979
+ val[I],
980
+ ) for I in CartesianIndices (val)
981
+ )
982
+ end
983
+ # Containers.
984
+ function varname_and_value_leaves_inner (vn:: VarName , val:: AbstractArray )
985
+ return Iterators. flatten (
986
+ varname_and_value_leaves_inner (
987
+ VarName (vn, DynamicPPL. getlens (vn) ∘ DynamicPPL. Setfield. IndexLens (Tuple (I))),
988
+ val[I],
989
+ ) for I in CartesianIndices (val)
990
+ )
991
+ end
992
+ function varname_and_value_leaves_inner (vn:: DynamicPPL.VarName , val:: NamedTuple )
993
+ iter = Iterators. map (keys (val)) do sym
994
+ lens = DynamicPPL. Setfield. PropertyLens {sym} ()
995
+ varname_and_value_leaves_inner (vn ∘ lens, get (val, lens))
996
+ end
997
+
998
+ return Iterators. flatten (iter)
999
+ end
1000
+ # Special types.
1001
+ function varname_and_value_leaves_inner (vn:: VarName , x:: Cholesky )
1002
+ # TODO : Or do we use `PDMat` here?
1003
+ return varname_and_value_leaves_inner (vn, x. UL)
1004
+ end
1005
+ function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.LowerTriangular )
1006
+ return (
1007
+ Leaf (
1008
+ VarName (vn, DynamicPPL. getlens (vn) ∘ DynamicPPL. Setfield. IndexLens (Tuple (I))),
1009
+ x[I],
1010
+ )
1011
+ # Iteration over the lower-triangular indices.
1012
+ for I in CartesianIndices (x) if I[1 ] >= I[2 ]
1013
+ )
1014
+ end
1015
+ function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.UpperTriangular )
1016
+ return (
1017
+ Leaf (
1018
+ VarName (vn, DynamicPPL. getlens (vn) ∘ DynamicPPL. Setfield. IndexLens (Tuple (I))),
1019
+ x[I],
1020
+ )
1021
+ # Iteration over the upper-triangular indices.
1022
+ for I in CartesianIndices (x) if I[1 ] <= I[2 ]
1023
+ )
1024
+ end
0 commit comments