Skip to content

Commit 6441754

Browse files
authored
move varname leaves code to AbstractPPL (#134)
* Move `varname{_and_value,}_leaves` to AbstractPPL * Write changelog, add exports * Add LinAlg dep * Fix various deps and stuff * OrderedCollections also needs to be a test dep
1 parent 8dddcea commit 6441754

File tree

7 files changed

+265
-6
lines changed

7 files changed

+265
-6
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.14.0
2+
3+
Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL.
4+
They are now part of the public API of AbstractPPL.
5+
16
## 0.13.0
27

38
Minimum compatibility has been bumped to Julia 1.10.

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,30 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.13.0"
6+
version = "0.14.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
1010
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
1111
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1212
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1516

1617
[weakdeps]
1718
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
18-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1919

2020
[extensions]
21-
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]
21+
AbstractPPLDistributionsExt = ["Distributions"]
2222

2323
[compat]
2424
AbstractMCMC = "2, 3, 4, 5"
2525
Accessors = "0.1"
2626
DensityInterface = "0.4"
2727
Distributions = "0.25"
28-
LinearAlgebra = "<0.0.1, 1.10"
2928
JSON = "0.19 - 0.21"
29+
LinearAlgebra = "<0.0.1, 1.10"
3030
Random = "1.6"
3131
StatsBase = "0.32, 0.33, 0.34"
3232
julia = "1.10"

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ hasvalue
2828
getvalue
2929
```
3030

31+
## Splitting VarNames up into components
32+
33+
```@docs
34+
varname_leaves
35+
varname_and_value_leaves
36+
```
37+
3138
## VarName serialisation
3239

3340
```@docs

ext/AbstractPPLDistributionsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ This decision may be revisited in the future.
4949

5050
module AbstractPPLDistributionsExt
5151

52-
using AbstractPPL: AbstractPPL, VarName, Accessors
52+
using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra
5353
using Distributions: Distributions
5454
using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular
5555

src/AbstractPPL.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ export VarName,
1818
prefix,
1919
unprefix,
2020
getvalue,
21-
hasvalue
21+
hasvalue,
22+
varname_leaves,
23+
varname_and_value_leaves
2224

2325
# Abstract model functions
2426
export AbstractProbabilisticProgram,
@@ -31,6 +33,7 @@ include("varname.jl")
3133
include("abstractmodeltrace.jl")
3234
include("abstractprobprog.jl")
3335
include("evaluate.jl")
36+
include("varname_leaves.jl")
3437
include("hasvalue.jl")
3538

3639
end # module

src/varname_leaves.jl

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
using LinearAlgebra: LinearAlgebra
2+
3+
"""
4+
varname_leaves(vn::VarName, val)
5+
6+
Return an iterator over all varnames that are represented by `vn` on `val`.
7+
8+
# Examples
9+
```jldoctest
10+
julia> using AbstractPPL: varname_leaves
11+
12+
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
13+
x[1]
14+
x[2]
15+
16+
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
17+
x[1:2][1]
18+
x[1:2][2]
19+
20+
julia> x = (y = 1, z = [[2.0], [3.0]]);
21+
22+
julia> foreach(println, varname_leaves(@varname(x), x))
23+
x.y
24+
x.z[1][1]
25+
x.z[2][1]
26+
```
27+
"""
28+
varname_leaves(vn::VarName, ::Real) = [vn]
29+
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
30+
return (
31+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
32+
I in CartesianIndices(val)
33+
)
34+
end
35+
function varname_leaves(vn::VarName, val::AbstractArray)
36+
return Iterators.flatten(
37+
varname_leaves(
38+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I]
39+
) for I in CartesianIndices(val)
40+
)
41+
end
42+
function varname_leaves(vn::VarName, val::NamedTuple)
43+
iter = Iterators.map(keys(val)) do k
44+
optic = Accessors.PropertyLens{k}()
45+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), optic(val))
46+
end
47+
return Iterators.flatten(iter)
48+
end
49+
50+
"""
51+
varname_and_value_leaves(vn::VarName, val)
52+
53+
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
54+
55+
# Examples
56+
```jldoctest varname-and-value-leaves
57+
julia> using AbstractPPL: varname_and_value_leaves
58+
59+
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
60+
(x[1], 1)
61+
(x[2], 2)
62+
63+
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
64+
(x[1:2][1], 1)
65+
(x[1:2][2], 2)
66+
67+
julia> x = (y = 1, z = [[2.0], [3.0]]);
68+
69+
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
70+
(x.y, 1)
71+
(x.z[1][1], 2.0)
72+
(x.z[2][1], 3.0)
73+
```
74+
75+
There is also some special handling for certain types:
76+
77+
```jldoctest varname-and-value-leaves
78+
julia> using LinearAlgebra
79+
80+
julia> x = reshape(1:4, 2, 2);
81+
82+
julia> # `LowerTriangular`
83+
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
84+
(x[1, 1], 1)
85+
(x[2, 1], 2)
86+
(x[2, 2], 4)
87+
88+
julia> # `UpperTriangular`
89+
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
90+
(x[1, 1], 1)
91+
(x[1, 2], 3)
92+
(x[2, 2], 4)
93+
94+
julia> # `Cholesky` with lower-triangular
95+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
96+
(x.L[1, 1], 1.0)
97+
(x.L[2, 1], 0.0)
98+
(x.L[2, 2], 1.0)
99+
100+
julia> # `Cholesky` with upper-triangular
101+
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
102+
(x.U[1, 1], 1.0)
103+
(x.U[1, 2], 0.0)
104+
(x.U[2, 2], 1.0)
105+
```
106+
"""
107+
function varname_and_value_leaves(vn::VarName, x)
108+
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
109+
end
110+
111+
"""
112+
varname_and_value_leaves(container)
113+
114+
Return an iterator over all varname-value pairs that are represented by `container`.
115+
116+
This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
117+
containing multiple varnames.
118+
119+
See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).
120+
121+
# Examples
122+
```jldoctest varname-and-value-leaves-container
123+
julia> using AbstractPPL: varname_and_value_leaves
124+
125+
julia> using OrderedCollections: OrderedDict
126+
127+
julia> # With an `AbstractDict` (we use `OrderedDict` here
128+
# to ensure consistent ordering in doctests)
129+
dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
130+
131+
julia> foreach(println, varname_and_value_leaves(dict))
132+
(y, 1)
133+
(z[1][1], 2.0)
134+
(z[2][1], 3.0)
135+
136+
julia> # With a `NamedTuple`
137+
nt = (y = 1, z = [[2.0], [3.0]]);
138+
139+
julia> foreach(println, varname_and_value_leaves(nt))
140+
(y, 1)
141+
(z[1][1], 2.0)
142+
(z[2][1], 3.0)
143+
```
144+
"""
145+
function varname_and_value_leaves(container::AbstractDict)
146+
return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container)
147+
end
148+
function varname_and_value_leaves(container::NamedTuple)
149+
return Iterators.flatten(
150+
varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container)
151+
)
152+
end
153+
154+
"""
155+
Leaf{T}
156+
157+
A container that represents the leaf of a nested structure, implementing
158+
`iterate` to return itself.
159+
160+
This is particularly useful in conjunction with `Iterators.flatten` to
161+
prevent flattening of nested structures.
162+
"""
163+
struct Leaf{T}
164+
value::T
165+
end
166+
167+
Leaf(xs...) = Leaf(xs)
168+
169+
# Allow us to treat `Leaf` as an iterator containing a single element.
170+
# Something like an `[x]` would also be an iterator with a single element,
171+
# but when we call `flatten` on this, it would also iterate over `x`,
172+
# unflattening that too. By making `Leaf` a single-element iterator, which
173+
# returns itself, we can call `iterate` on this as many times as we like
174+
# without causing any change. The result is that `Iterators.flatten`
175+
# will _not_ unflatten `Leaf`s.
176+
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
177+
#
178+
# julia> iterate(1)
179+
# (1, nothing)
180+
#
181+
# One immediate example where this becomes in our scenario is that we might
182+
# have `missing` values in our data, which does _not_ have an `iterate`
183+
# implemented. Calling `Iterators.flatten` on this would cause an error.
184+
Base.iterate(leaf::Leaf) = leaf, nothing
185+
Base.iterate(::Leaf, _) = nothing
186+
187+
# Convenience.
188+
value(leaf::Leaf) = leaf.value
189+
190+
# Leaf-types.
191+
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
192+
function varname_and_value_leaves_inner(
193+
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
194+
)
195+
return (
196+
Leaf(
197+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
198+
val[I],
199+
) for I in CartesianIndices(val)
200+
)
201+
end
202+
# Containers.
203+
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
204+
return Iterators.flatten(
205+
varname_and_value_leaves_inner(
206+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
207+
val[I],
208+
) for I in CartesianIndices(val)
209+
)
210+
end
211+
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
212+
iter = Iterators.map(keys(val)) do k
213+
optic = Accessors.PropertyLens{k}()
214+
varname_and_value_leaves_inner(
215+
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
216+
)
217+
end
218+
219+
return Iterators.flatten(iter)
220+
end
221+
# Special types.
222+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky)
223+
# TODO: Or do we use `PDMat` here?
224+
return if x.uplo == 'L'
225+
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
226+
else
227+
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() vn, x.U)
228+
end
229+
end
230+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
231+
return (
232+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
233+
# Iteration over the lower-triangular indices.
234+
for I in CartesianIndices(x) if I[1] >= I[2]
235+
)
236+
end
237+
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
238+
return (
239+
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
240+
# Iteration over the upper-triangular indices.
241+
for I in CartesianIndices(x) if I[1] <= I[2]
242+
)
243+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
9+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

0 commit comments

Comments
 (0)