Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.14.0

Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL.
They are now part of the public API of AbstractPPL.

## 0.13.0

Minimum compatibility has been bumped to Julia 1.10.
Expand Down
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.13.0"
version = "0.14.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[extensions]
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]
AbstractPPLDistributionsExt = ["Distributions"]

[compat]
AbstractMCMC = "2, 3, 4, 5"
Accessors = "0.1"
DensityInterface = "0.4"
Distributions = "0.25"
LinearAlgebra = "<0.0.1, 1.10"
JSON = "0.19 - 0.21"
LinearAlgebra = "<0.0.1, 1.10"
Random = "1.6"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.10"
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ hasvalue
getvalue
```

## Splitting VarNames up into components

```@docs
varname_leaves
varname_and_value_leaves
```

## VarName serialisation

```@docs
Expand Down
2 changes: 1 addition & 1 deletion ext/AbstractPPLDistributionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ This decision may be revisited in the future.

module AbstractPPLDistributionsExt

using AbstractPPL: AbstractPPL, VarName, Accessors
using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra
using Distributions: Distributions
using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular

Expand Down
5 changes: 4 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ export VarName,
prefix,
unprefix,
getvalue,
hasvalue
hasvalue,
varname_leaves,
varname_and_value_leaves

# Abstract model functions
export AbstractProbabilisticProgram,
Expand All @@ -31,6 +33,7 @@ include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("evaluate.jl")
include("varname_leaves.jl")
include("hasvalue.jl")

end # module
243 changes: 243 additions & 0 deletions src/varname_leaves.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
using LinearAlgebra: LinearAlgebra

"""
varname_leaves(vn::VarName, val)

Return an iterator over all varnames that are represented by `vn` on `val`.

# Examples
```jldoctest
julia> using AbstractPPL: varname_leaves

julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]

julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
```
"""
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val))
end
return Iterators.flatten(iter)
end

"""
varname_and_value_leaves(vn::VarName, val)

Return an iterator over all varname-value pairs that are represented by `vn` on `val`.

# Examples
```jldoctest varname-and-value-leaves
julia> using AbstractPPL: varname_and_value_leaves

julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)

julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```

There is also some special handling for certain types:

```jldoctest varname-and-value-leaves
julia> using LinearAlgebra

julia> x = reshape(1:4, 2, 2);

julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1, 1], 1)
(x[2, 1], 2)
(x[2, 2], 4)

julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1, 1], 1)
(x[1, 2], 3)
(x[2, 2], 4)

julia> # `Cholesky` with lower-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
(x.L[1, 1], 1.0)
(x.L[2, 1], 0.0)
(x.L[2, 2], 1.0)

julia> # `Cholesky` with upper-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
(x.U[1, 1], 1.0)
(x.U[1, 2], 0.0)
(x.U[2, 2], 1.0)
```
"""
function varname_and_value_leaves(vn::VarName, x)
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
end

"""
varname_and_value_leaves(container)

Return an iterator over all varname-value pairs that are represented by `container`.

This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
containing multiple varnames.

See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).

# Examples
```jldoctest varname-and-value-leaves-container
julia> using AbstractPPL: varname_and_value_leaves

julia> using OrderedCollections: OrderedDict

julia> # With an `AbstractDict` (we use `OrderedDict` here
# to ensure consistent ordering in doctests)
dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(dict))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)

julia> # With a `NamedTuple`
nt = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(nt))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)
```
"""
function varname_and_value_leaves(container::AbstractDict)
return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container)
end
function varname_and_value_leaves(container::NamedTuple)
return Iterators.flatten(
varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container)
)
end

"""
Leaf{T}

A container that represents the leaf of a nested structure, implementing
`iterate` to return itself.

This is particularly useful in conjunction with `Iterators.flatten` to
prevent flattening of nested structures.
"""
struct Leaf{T}
value::T
end

Leaf(xs...) = Leaf(xs)

# Allow us to treat `Leaf` as an iterator containing a single element.
# Something like an `[x]` would also be an iterator with a single element,
# but when we call `flatten` on this, it would also iterate over `x`,
# unflattening that too. By making `Leaf` a single-element iterator, which
# returns itself, we can call `iterate` on this as many times as we like
# without causing any change. The result is that `Iterators.flatten`
# will _not_ unflatten `Leaf`s.
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
#
# julia> iterate(1)
# (1, nothing)
#
# One immediate example where this becomes in our scenario is that we might
# have `missing` values in our data, which does _not_ have an `iterate`
# implemented. Calling `Iterators.flatten` on this would cause an error.
Base.iterate(leaf::Leaf) = leaf, nothing
Base.iterate(::Leaf, _) = nothing

# Convenience.
value(leaf::Leaf) = leaf.value

# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
return (
Leaf(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_and_value_leaves_inner(
VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
)
end

return Iterators.flatten(iter)
end
# Special types.
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky)
# TODO: Or do we use `PDMat` here?
return if x.uplo == 'L'
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L)
else
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U)
end
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
Loading