diff --git a/HISTORY.md b/HISTORY.md index 66fba61..d8542a3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,8 @@ +## 0.14.0 + +Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL. +They are now part of the public API of AbstractPPL. + ## 0.13.0 Minimum compatibility has been bumped to Julia 1.10. diff --git a/Project.toml b/Project.toml index 148cde2..5edd9e2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,30 +3,30 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.13.0" +version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] -AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] +AbstractPPLDistributionsExt = ["Distributions"] [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" DensityInterface = "0.4" Distributions = "0.25" -LinearAlgebra = "<0.0.1, 1.10" JSON = "0.19 - 0.21" +LinearAlgebra = "<0.0.1, 1.10" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index 9c05c47..9f289ab 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -28,6 +28,13 @@ hasvalue getvalue ``` +## Splitting VarNames up into components + +```@docs +varname_leaves +varname_and_value_leaves +``` + ## VarName serialisation ```@docs diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 857e04e..d58c05c 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -49,7 +49,7 @@ This decision may be revisited in the future. module AbstractPPLDistributionsExt -using AbstractPPL: AbstractPPL, VarName, Accessors +using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra using Distributions: Distributions using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index dd495db..30a4f7e 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -18,7 +18,9 @@ export VarName, prefix, unprefix, getvalue, - hasvalue + hasvalue, + varname_leaves, + varname_and_value_leaves # Abstract model functions export AbstractProbabilisticProgram, @@ -31,6 +33,7 @@ include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("varname_leaves.jl") include("hasvalue.jl") end # module diff --git a/src/varname_leaves.jl b/src/varname_leaves.jl new file mode 100644 index 0000000..e4025e8 --- /dev/null +++ b/src/varname_leaves.jl @@ -0,0 +1,243 @@ +using LinearAlgebra: LinearAlgebra + +""" + varname_leaves(vn::VarName, val) + +Return an iterator over all varnames that are represented by `vn` on `val`. + +# Examples +```jldoctest +julia> using AbstractPPL: varname_leaves + +julia> foreach(println, varname_leaves(@varname(x), rand(2))) +x[1] +x[2] + +julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) +x[1:2][1] +x[1:2][2] + +julia> x = (y = 1, z = [[2.0], [3.0]]); + +julia> foreach(println, varname_leaves(@varname(x), x)) +x.y +x.z[1][1] +x.z[2][1] +``` +""" +varname_leaves(vn::VarName, ::Real) = [vn] +function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for + I in CartesianIndices(val) + ) +end +function varname_leaves(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_leaves( + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] + ) for I in CartesianIndices(val) + ) +end +function varname_leaves(vn::VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do k + optic = Accessors.PropertyLens{k}() + varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) + end + return Iterators.flatten(iter) +end + +""" + varname_and_value_leaves(vn::VarName, val) + +Return an iterator over all varname-value pairs that are represented by `vn` on `val`. + +# Examples +```jldoctest varname-and-value-leaves +julia> using AbstractPPL: varname_and_value_leaves + +julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) +(x[1], 1) +(x[2], 2) + +julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) +(x[1:2][1], 1) +(x[1:2][2], 2) + +julia> x = (y = 1, z = [[2.0], [3.0]]); + +julia> foreach(println, varname_and_value_leaves(@varname(x), x)) +(x.y, 1) +(x.z[1][1], 2.0) +(x.z[2][1], 3.0) +``` + +There is also some special handling for certain types: + +```jldoctest varname-and-value-leaves +julia> using LinearAlgebra + +julia> x = reshape(1:4, 2, 2); + +julia> # `LowerTriangular` + foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) +(x[1, 1], 1) +(x[2, 1], 2) +(x[2, 2], 4) + +julia> # `UpperTriangular` + foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) +(x[1, 1], 1) +(x[1, 2], 3) +(x[2, 2], 4) + +julia> # `Cholesky` with lower-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) +(x.L[1, 1], 1.0) +(x.L[2, 1], 0.0) +(x.L[2, 2], 1.0) + +julia> # `Cholesky` with upper-triangular + foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) +(x.U[1, 1], 1.0) +(x.U[1, 2], 0.0) +(x.U[2, 2], 1.0) +``` +""" +function varname_and_value_leaves(vn::VarName, x) + return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) +end + +""" + varname_and_value_leaves(container) + +Return an iterator over all varname-value pairs that are represented by `container`. + +This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container +containing multiple varnames. + +See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref). + +# Examples +```jldoctest varname-and-value-leaves-container +julia> using AbstractPPL: varname_and_value_leaves + +julia> using OrderedCollections: OrderedDict + +julia> # With an `AbstractDict` (we use `OrderedDict` here + # to ensure consistent ordering in doctests) + dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]); + +julia> foreach(println, varname_and_value_leaves(dict)) +(y, 1) +(z[1][1], 2.0) +(z[2][1], 3.0) + +julia> # With a `NamedTuple` + nt = (y = 1, z = [[2.0], [3.0]]); + +julia> foreach(println, varname_and_value_leaves(nt)) +(y, 1) +(z[1][1], 2.0) +(z[2][1], 3.0) +``` +""" +function varname_and_value_leaves(container::AbstractDict) + return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container) +end +function varname_and_value_leaves(container::NamedTuple) + return Iterators.flatten( + varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container) + ) +end + +""" + Leaf{T} + +A container that represents the leaf of a nested structure, implementing +`iterate` to return itself. + +This is particularly useful in conjunction with `Iterators.flatten` to +prevent flattening of nested structures. +""" +struct Leaf{T} + value::T +end + +Leaf(xs...) = Leaf(xs) + +# Allow us to treat `Leaf` as an iterator containing a single element. +# Something like an `[x]` would also be an iterator with a single element, +# but when we call `flatten` on this, it would also iterate over `x`, +# unflattening that too. By making `Leaf` a single-element iterator, which +# returns itself, we can call `iterate` on this as many times as we like +# without causing any change. The result is that `Iterators.flatten` +# will _not_ unflatten `Leaf`s. +# Note that this is similar to how `Base.iterate` is implemented for `Real`:: +# +# julia> iterate(1) +# (1, nothing) +# +# One immediate example where this becomes in our scenario is that we might +# have `missing` values in our data, which does _not_ have an `iterate` +# implemented. Calling `Iterators.flatten` on this would cause an error. +Base.iterate(leaf::Leaf) = leaf, nothing +Base.iterate(::Leaf, _) = nothing + +# Convenience. +value(leaf::Leaf) = leaf.value + +# Leaf-types. +varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] +function varname_and_value_leaves_inner( + vn::VarName, val::AbstractArray{<:Union{Real,Missing}} +) + return ( + Leaf( + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + val[I], + ) for I in CartesianIndices(val) + ) +end +# Containers. +function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_and_value_leaves_inner( + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + val[I], + ) for I in CartesianIndices(val) + ) +end +function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do k + optic = Accessors.PropertyLens{k}() + varname_and_value_leaves_inner( + VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) + ) + end + + return Iterators.flatten(iter) +end +# Special types. +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky) + # TODO: Or do we use `PDMat` here? + return if x.uplo == 'L' + varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) + else + varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) + end +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) + return ( + Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + # Iteration over the lower-triangular indices. + for I in CartesianIndices(x) if I[1] >= I[2] + ) +end +function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) + return ( + Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + # Iteration over the upper-triangular indices. + for I in CartesianIndices(x) if I[1] <= I[2] + ) +end diff --git a/test/Project.toml b/test/Project.toml index e19c50e..b2a8fca 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"