Skip to content

Commit a140fda

Browse files
committed
Move varname{_and_value,}_leaves to AbstractPPL
1 parent 8dddcea commit a140fda

File tree

3 files changed

+240
-1
lines changed

3 files changed

+240
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.13.0"
6+
version = "0.14.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AbstractPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ include("varname.jl")
3131
include("abstractmodeltrace.jl")
3232
include("abstractprobprog.jl")
3333
include("evaluate.jl")
34+
include("varname_leaves.jl")
3435
include("hasvalue.jl")
3536

3637
end # module

src/varname_leaves.jl

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
varname_leaves(vn::VarName, val)
3+
4+
Return an iterator over all varnames that are represented by `vn` on `val`.
5+
6+
# Examples
7+
```jldoctest
8+
julia> using AbstractPPL: varname_leaves
9+
10+
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
11+
x[1]
12+
x[2]
13+
14+
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
15+
x[1:2][1]
16+
x[1:2][2]
17+
18+
julia> x = (y = 1, z = [[2.0], [3.0]]);
19+
20+
julia> foreach(println, varname_leaves(@varname(x), x))
21+
x.y
22+
x.z[1][1]
23+
x.z[2][1]
24+
```
25+
"""
26+
varname_leaves(vn::VarName, ::Real) = [vn]
27+
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
28+
return (
29+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
30+
I in CartesianIndices(val)
31+
)
32+
end
33+
function varname_leaves(vn::VarName, val::AbstractArray)
34+
return Iterators.flatten(
35+
varname_leaves(
36+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I]
37+
) for I in CartesianIndices(val)
38+
)
39+
end
40+
function varname_leaves(vn::VarName, val::NamedTuple)
41+
iter = Iterators.map(keys(val)) do k
42+
optic = Accessors.PropertyLens{k}()
43+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), optic(val))
44+
end
45+
return Iterators.flatten(iter)
46+
end
47+
48+
"""
49+
varname_and_value_leaves(vn::VarName, val)
50+
51+
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
52+
53+
# Examples
54+
```jldoctest varname-and-value-leaves
55+
julia> using AbstractPPL: varname_and_value_leaves
56+
57+
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
58+
(x[1], 1)
59+
(x[2], 2)
60+
61+
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
62+
(x[1:2][1], 1)
63+
(x[1:2][2], 2)
64+
65+
julia> x = (y = 1, z = [[2.0], [3.0]]);
66+
67+
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
68+
(x.y, 1)
69+
(x.z[1][1], 2.0)
70+
(x.z[2][1], 3.0)
71+
```
72+
73+
There is also some special handling for certain types:
74+
75+
```jldoctest varname-and-value-leaves
76+
julia> using LinearAlgebra
77+
78+
julia> x = reshape(1:4, 2, 2);
79+
80+
julia> # `LowerTriangular`
81+
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
82+
(x[1, 1], 1)
83+
(x[2, 1], 2)
84+
(x[2, 2], 4)
85+
86+
julia> # `UpperTriangular`
87+
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
88+
(x[1, 1], 1)
89+
(x[1, 2], 3)
90+
(x[2, 2], 4)
91+
92+
julia> # `Cholesky` with lower-triangular
93+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
94+
(x.L[1, 1], 1.0)
95+
(x.L[2, 1], 0.0)
96+
(x.L[2, 2], 1.0)
97+
98+
julia> # `Cholesky` with upper-triangular
99+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
100+
(x.U[1, 1], 1.0)
101+
(x.U[1, 2], 0.0)
102+
(x.U[2, 2], 1.0)
103+
```
104+
"""
105+
function varname_and_value_leaves(vn::VarName, x)
106+
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
107+
end
108+
109+
"""
110+
varname_and_value_leaves(container)
111+
112+
Return an iterator over all varname-value pairs that are represented by `container`.
113+
114+
This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
115+
containing multiple varnames.
116+
117+
See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).
118+
119+
# Examples
120+
```jldoctest varname-and-value-leaves-container
121+
julia> using AbstractPPL: varname_and_value_leaves
122+
123+
julia> # With an `AbstractDict`
124+
dict = Dict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
125+
126+
julia> foreach(println, varname_and_value_leaves(dict))
127+
(y, 1)
128+
(z[1][1], 2.0)
129+
(z[2][1], 3.0)
130+
131+
julia> # With a `NamedTuple`
132+
nt = (y = 1, z = [[2.0], [3.0]]);
133+
134+
julia> foreach(println, varname_and_value_leaves(nt))
135+
(y, 1)
136+
(z[1][1], 2.0)
137+
(z[2][1], 3.0)
138+
```
139+
"""
140+
function varname_and_value_leaves(container::AbstractDict)
141+
return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container)
142+
end
143+
function varname_and_value_leaves(container::NamedTuple)
144+
return Iterators.flatten(
145+
varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container)
146+
)
147+
end
148+
149+
"""
150+
Leaf{T}
151+
152+
A container that represents the leaf of a nested structure, implementing
153+
`iterate` to return itself.
154+
155+
This is particularly useful in conjunction with `Iterators.flatten` to
156+
prevent flattening of nested structures.
157+
"""
158+
struct Leaf{T}
159+
value::T
160+
end
161+
162+
Leaf(xs...) = Leaf(xs)
163+
164+
# Allow us to treat `Leaf` as an iterator containing a single element.
165+
# Something like an `[x]` would also be an iterator with a single element,
166+
# but when we call `flatten` on this, it would also iterate over `x`,
167+
# unflattening that too. By making `Leaf` a single-element iterator, which
168+
# returns itself, we can call `iterate` on this as many times as we like
169+
# without causing any change. The result is that `Iterators.flatten`
170+
# will _not_ unflatten `Leaf`s.
171+
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
172+
#
173+
# julia> iterate(1)
174+
# (1, nothing)
175+
#
176+
# One immediate example where this becomes in our scenario is that we might
177+
# have `missing` values in our data, which does _not_ have an `iterate`
178+
# implemented. Calling `Iterators.flatten` on this would cause an error.
179+
Base.iterate(leaf::Leaf) = leaf, nothing
180+
Base.iterate(::Leaf, _) = nothing
181+
182+
# Convenience.
183+
value(leaf::Leaf) = leaf.value
184+
185+
# Leaf-types.
186+
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
187+
function varname_and_value_leaves_inner(
188+
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
189+
)
190+
return (
191+
Leaf(
192+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
193+
val[I],
194+
) for I in CartesianIndices(val)
195+
)
196+
end
197+
# Containers.
198+
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
199+
return Iterators.flatten(
200+
varname_and_value_leaves_inner(
201+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
202+
val[I],
203+
) for I in CartesianIndices(val)
204+
)
205+
end
206+
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
207+
iter = Iterators.map(keys(val)) do k
208+
optic = Accessors.PropertyLens{k}()
209+
varname_and_value_leaves_inner(
210+
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
211+
)
212+
end
213+
214+
return Iterators.flatten(iter)
215+
end
216+
# Special types.
217+
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
218+
# TODO: Or do we use `PDMat` here?
219+
return if x.uplo == 'L'
220+
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
221+
else
222+
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() vn, x.U)
223+
end
224+
end
225+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
226+
return (
227+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
228+
# Iteration over the lower-triangular indices.
229+
for I in CartesianIndices(x) if I[1] >= I[2]
230+
)
231+
end
232+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
233+
return (
234+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
235+
# Iteration over the upper-triangular indices.
236+
for I in CartesianIndices(x) if I[1] <= I[2]
237+
)
238+
end

0 commit comments

Comments
 (0)