diff --git a/HISTORY.md b/HISTORY.md index 87bd2d552..91218d1fc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,9 @@ ## 0.38.0 +The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical. + [...] ## 0.37.1 diff --git a/Project.toml b/Project.toml index 8e0ada64e..6dff71a03 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/src/api.md b/docs/src/api.md index c6244b75f..d2150f3d7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -435,8 +435,6 @@ DynamicPPL.maybe_invlink_before_eval!! Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten -DynamicPPL.varname_leaves -DynamicPPL.varname_and_value_leaves ``` ### Evaluation Contexts diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..48efc1464 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,12 +1,7 @@ module DynamicPPLMCMCChainsExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using MCMCChains: MCMCChains -else - using ..DynamicPPL: DynamicPPL - using ..MCMCChains: MCMCChains -end +using DynamicPPL: DynamicPPL, AbstractPPL +using MCMCChains: MCMCChains # Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata function DynamicPPL.loadstate(chain::MCMCChains.Chains) @@ -121,7 +116,7 @@ function DynamicPPL.predict( varname_vals = mapreduce( collect, vcat, - map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), + map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), ) return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) diff --git a/src/utils.jl b/src/utils.jl index d3371271f..c7d1e089f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -837,245 +837,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -""" - varname_leaves(vn::VarName, val) - -Return an iterator over all varnames that are represented by `vn` on `val`. - -# Examples -```jldoctest -julia> using DynamicPPL: varname_leaves - -julia> foreach(println, varname_leaves(@varname(x), rand(2))) -x[1] -x[2] - -julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) -x[1:2][1] -x[1:2][2] - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_leaves(@varname(x), x)) -x.y -x.z[1][1] -x.z[2][1] -``` -""" -varname_leaves(vn::VarName, ::Real) = [vn] -function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for - I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] - ) for I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) - end - return Iterators.flatten(iter) -end - -""" - varname_and_value_leaves(vn::VarName, val) - -Return an iterator over all varname-value pairs that are represented by `vn` on `val`. - -# Examples -```jldoctest varname-and-value-leaves -julia> using DynamicPPL: varname_and_value_leaves - -julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) -(x[1], 1) -(x[2], 2) - -julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) -(x[1:2][1], 1) -(x[1:2][2], 2) - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(@varname(x), x)) -(x.y, 1) -(x.z[1][1], 2.0) -(x.z[2][1], 3.0) -``` - -There is also some special handling for certain types: - -```jldoctest varname-and-value-leaves -julia> using LinearAlgebra - -julia> x = reshape(1:4, 2, 2); - -julia> # `LowerTriangular` - foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1, 1], 1) -(x[2, 1], 2) -(x[2, 2], 4) - -julia> # `UpperTriangular` - foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1, 1], 1) -(x[1, 2], 3) -(x[2, 2], 4) - -julia> # `Cholesky` with lower-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1, 1], 1.0) -(x.L[2, 1], 0.0) -(x.L[2, 2], 1.0) - -julia> # `Cholesky` with upper-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1, 1], 1.0) -(x.U[1, 2], 0.0) -(x.U[2, 2], 1.0) -``` -""" -function varname_and_value_leaves(vn::VarName, x) - return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) -end - -""" - varname_and_value_leaves(container) - -Return an iterator over all varname-value pairs that are represented by `container`. - -This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container -containing multiple varnames. - -See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref). - -# Examples -```jldoctest varname-and-value-leaves-container -julia> using DynamicPPL: varname_and_value_leaves - -julia> # With an `OrderedDict` - dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(dict)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) - -julia> # With a `NamedTuple` - nt = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(nt)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) -``` -""" -function varname_and_value_leaves(container::OrderedDict) - return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container) -end -function varname_and_value_leaves(container::NamedTuple) - return Iterators.flatten( - varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container) - ) -end - -""" - Leaf{T} - -A container that represents the leaf of a nested structure, implementing -`iterate` to return itself. - -This is particularly useful in conjunction with `Iterators.flatten` to -prevent flattening of nested structures. -""" -struct Leaf{T} - value::T -end - -Leaf(xs...) = Leaf(xs) - -# Allow us to treat `Leaf` as an iterator containing a single element. -# Something like an `[x]` would also be an iterator with a single element, -# but when we call `flatten` on this, it would also iterate over `x`, -# unflattening that too. By making `Leaf` a single-element iterator, which -# returns itself, we can call `iterate` on this as many times as we like -# without causing any change. The result is that `Iterators.flatten` -# will _not_ unflatten `Leaf`s. -# Note that this is similar to how `Base.iterate` is implemented for `Real`:: -# -# julia> iterate(1) -# (1, nothing) -# -# One immediate example where this becomes in our scenario is that we might -# have `missing` values in our data, which does _not_ have an `iterate` -# implemented. Calling `Iterators.flatten` on this would cause an error. -Base.iterate(leaf::Leaf) = leaf, nothing -Base.iterate(::Leaf, _) = nothing - -# Convenience. -value(leaf::Leaf) = leaf.value - -# Leaf-types. -varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] -function varname_and_value_leaves_inner( - vn::VarName, val::AbstractArray{<:Union{Real,Missing}} -) - return ( - Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -# Containers. -function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_and_value_leaves_inner( - VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) - ) - end - - return Iterators.flatten(iter) -end -# Special types. -function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) - # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) - else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) - end -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the lower-triangular indices. - for I in CartesianIndices(x) if I[1] >= I[2] - ) -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the upper-triangular indices. - for I in CartesianIndices(x) if I[1] <= I[2] - ) -end - broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) diff --git a/test/model.jl b/test/model.jl index 81f84e548..f062a70b4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -347,7 +347,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Extract varnames and values. vns_and_vals_xs = map( - collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs + collect ∘ Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs ) vns = map(first, first(vns_and_vals_xs)) vals = map(vns_and_vals_xs) do vns_and_vals diff --git a/test/test_util.jl b/test/test_util.jl index e04486760..c6762ed45 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -72,7 +72,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I # We have to use varname_and_value_leaves so that each parameter is a scalar dicts = map(varinfos) do t vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) tuples = mapreduce(collect, vcat, iters) # The following loop is a replacement for: # push!(varnames, map(first, tuples)...) diff --git a/test/varinfo.jl b/test/varinfo.jl index ba7c17b34..f36af44a1 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -488,7 +488,7 @@ end θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) for (n, v) in mapreduce(collect, vcat, iters) n = string(n) if Symbol(n) ∉ keys(chain)