@@ -837,245 +837,6 @@ end
837
837
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
838
838
infer_nested_eltype (:: Type{<:AbstractDict{<:Any,ET}} ) where {ET} = infer_nested_eltype (ET)
839
839
840
- """
841
- varname_leaves(vn::VarName, val)
842
-
843
- Return an iterator over all varnames that are represented by `vn` on `val`.
844
-
845
- # Examples
846
- ```jldoctest
847
- julia> using DynamicPPL: varname_leaves
848
-
849
- julia> foreach(println, varname_leaves(@varname(x), rand(2)))
850
- x[1]
851
- x[2]
852
-
853
- julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
854
- x[1:2][1]
855
- x[1:2][2]
856
-
857
- julia> x = (y = 1, z = [[2.0], [3.0]]);
858
-
859
- julia> foreach(println, varname_leaves(@varname(x), x))
860
- x.y
861
- x.z[1][1]
862
- x.z[2][1]
863
- ```
864
- """
865
- varname_leaves (vn:: VarName , :: Real ) = [vn]
866
- function varname_leaves (vn:: VarName , val:: AbstractArray{<:Union{Real,Missing}} )
867
- return (
868
- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)) for
869
- I in CartesianIndices (val)
870
- )
871
- end
872
- function varname_leaves (vn:: VarName , val:: AbstractArray )
873
- return Iterators. flatten (
874
- varname_leaves (
875
- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), val[I]
876
- ) for I in CartesianIndices (val)
877
- )
878
- end
879
- function varname_leaves (vn:: VarName , val:: NamedTuple )
880
- iter = Iterators. map (keys (val)) do k
881
- optic = Accessors. PropertyLens {k} ()
882
- varname_leaves (VarName {getsym(vn)} (optic ∘ getoptic (vn)), optic (val))
883
- end
884
- return Iterators. flatten (iter)
885
- end
886
-
887
- """
888
- varname_and_value_leaves(vn::VarName, val)
889
-
890
- Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
891
-
892
- # Examples
893
- ```jldoctest varname-and-value-leaves
894
- julia> using DynamicPPL: varname_and_value_leaves
895
-
896
- julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
897
- (x[1], 1)
898
- (x[2], 2)
899
-
900
- julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
901
- (x[1:2][1], 1)
902
- (x[1:2][2], 2)
903
-
904
- julia> x = (y = 1, z = [[2.0], [3.0]]);
905
-
906
- julia> foreach(println, varname_and_value_leaves(@varname(x), x))
907
- (x.y, 1)
908
- (x.z[1][1], 2.0)
909
- (x.z[2][1], 3.0)
910
- ```
911
-
912
- There is also some special handling for certain types:
913
-
914
- ```jldoctest varname-and-value-leaves
915
- julia> using LinearAlgebra
916
-
917
- julia> x = reshape(1:4, 2, 2);
918
-
919
- julia> # `LowerTriangular`
920
- foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
921
- (x[1, 1], 1)
922
- (x[2, 1], 2)
923
- (x[2, 2], 4)
924
-
925
- julia> # `UpperTriangular`
926
- foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
927
- (x[1, 1], 1)
928
- (x[1, 2], 3)
929
- (x[2, 2], 4)
930
-
931
- julia> # `Cholesky` with lower-triangular
932
- foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
933
- (x.L[1, 1], 1.0)
934
- (x.L[2, 1], 0.0)
935
- (x.L[2, 2], 1.0)
936
-
937
- julia> # `Cholesky` with upper-triangular
938
- foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
939
- (x.U[1, 1], 1.0)
940
- (x.U[1, 2], 0.0)
941
- (x.U[2, 2], 1.0)
942
- ```
943
- """
944
- function varname_and_value_leaves (vn:: VarName , x)
945
- return Iterators. map (value, Iterators. flatten (varname_and_value_leaves_inner (vn, x)))
946
- end
947
-
948
- """
949
- varname_and_value_leaves(container)
950
-
951
- Return an iterator over all varname-value pairs that are represented by `container`.
952
-
953
- This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
954
- containing multiple varnames.
955
-
956
- See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).
957
-
958
- # Examples
959
- ```jldoctest varname-and-value-leaves-container
960
- julia> using DynamicPPL: varname_and_value_leaves
961
-
962
- julia> # With an `OrderedDict`
963
- dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
964
-
965
- julia> foreach(println, varname_and_value_leaves(dict))
966
- (y, 1)
967
- (z[1][1], 2.0)
968
- (z[2][1], 3.0)
969
-
970
- julia> # With a `NamedTuple`
971
- nt = (y = 1, z = [[2.0], [3.0]]);
972
-
973
- julia> foreach(println, varname_and_value_leaves(nt))
974
- (y, 1)
975
- (z[1][1], 2.0)
976
- (z[2][1], 3.0)
977
- ```
978
- """
979
- function varname_and_value_leaves (container:: OrderedDict )
980
- return Iterators. flatten (varname_and_value_leaves (k, v) for (k, v) in container)
981
- end
982
- function varname_and_value_leaves (container:: NamedTuple )
983
- return Iterators. flatten (
984
- varname_and_value_leaves (VarName {k} (), v) for (k, v) in pairs (container)
985
- )
986
- end
987
-
988
- """
989
- Leaf{T}
990
-
991
- A container that represents the leaf of a nested structure, implementing
992
- `iterate` to return itself.
993
-
994
- This is particularly useful in conjunction with `Iterators.flatten` to
995
- prevent flattening of nested structures.
996
- """
997
- struct Leaf{T}
998
- value:: T
999
- end
1000
-
1001
- Leaf (xs... ) = Leaf (xs)
1002
-
1003
- # Allow us to treat `Leaf` as an iterator containing a single element.
1004
- # Something like an `[x]` would also be an iterator with a single element,
1005
- # but when we call `flatten` on this, it would also iterate over `x`,
1006
- # unflattening that too. By making `Leaf` a single-element iterator, which
1007
- # returns itself, we can call `iterate` on this as many times as we like
1008
- # without causing any change. The result is that `Iterators.flatten`
1009
- # will _not_ unflatten `Leaf`s.
1010
- # Note that this is similar to how `Base.iterate` is implemented for `Real`::
1011
- #
1012
- # julia> iterate(1)
1013
- # (1, nothing)
1014
- #
1015
- # One immediate example where this becomes in our scenario is that we might
1016
- # have `missing` values in our data, which does _not_ have an `iterate`
1017
- # implemented. Calling `Iterators.flatten` on this would cause an error.
1018
- Base. iterate (leaf:: Leaf ) = leaf, nothing
1019
- Base. iterate (:: Leaf , _) = nothing
1020
-
1021
- # Convenience.
1022
- value (leaf:: Leaf ) = leaf. value
1023
-
1024
- # Leaf-types.
1025
- varname_and_value_leaves_inner (vn:: VarName , x:: Real ) = [Leaf (vn, x)]
1026
- function varname_and_value_leaves_inner (
1027
- vn:: VarName , val:: AbstractArray{<:Union{Real,Missing}}
1028
- )
1029
- return (
1030
- Leaf (
1031
- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ AbstractPPL. getoptic (vn)),
1032
- val[I],
1033
- ) for I in CartesianIndices (val)
1034
- )
1035
- end
1036
- # Containers.
1037
- function varname_and_value_leaves_inner (vn:: VarName , val:: AbstractArray )
1038
- return Iterators. flatten (
1039
- varname_and_value_leaves_inner (
1040
- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ AbstractPPL. getoptic (vn)),
1041
- val[I],
1042
- ) for I in CartesianIndices (val)
1043
- )
1044
- end
1045
- function varname_and_value_leaves_inner (vn:: VarName , val:: NamedTuple )
1046
- iter = Iterators. map (keys (val)) do k
1047
- optic = Accessors. PropertyLens {k} ()
1048
- varname_and_value_leaves_inner (
1049
- VarName {getsym(vn)} (optic ∘ getoptic (vn)), optic (val)
1050
- )
1051
- end
1052
-
1053
- return Iterators. flatten (iter)
1054
- end
1055
- # Special types.
1056
- function varname_and_value_leaves_inner (vn:: VarName , x:: Cholesky )
1057
- # TODO : Or do we use `PDMat` here?
1058
- return if x. uplo == ' L'
1059
- varname_and_value_leaves_inner (Accessors. PropertyLens {:L} () ∘ vn, x. L)
1060
- else
1061
- varname_and_value_leaves_inner (Accessors. PropertyLens {:U} () ∘ vn, x. U)
1062
- end
1063
- end
1064
- function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.LowerTriangular )
1065
- return (
1066
- Leaf (VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), x[I])
1067
- # Iteration over the lower-triangular indices.
1068
- for I in CartesianIndices (x) if I[1 ] >= I[2 ]
1069
- )
1070
- end
1071
- function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.UpperTriangular )
1072
- return (
1073
- Leaf (VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), x[I])
1074
- # Iteration over the upper-triangular indices.
1075
- for I in CartesianIndices (x) if I[1 ] <= I[2 ]
1076
- )
1077
- end
1078
-
1079
840
broadcast_safe (x) = x
1080
841
broadcast_safe (x:: Distribution ) = (x,)
1081
842
broadcast_safe (x:: AbstractContext ) = (x,)
0 commit comments