Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## 0.38.0

The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl.
Their behaviour is otherwise identical.

[...]

## 0.37.1
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.13"
AbstractPPL = "0.13.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
2 changes: 0 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,6 @@ DynamicPPL.maybe_invlink_before_eval!!
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

### Evaluation Contexts
Expand Down
11 changes: 3 additions & 8 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
module DynamicPPLMCMCChainsExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
end
using DynamicPPL: DynamicPPL, AbstractPPL
using MCMCChains: MCMCChains

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
Expand Down Expand Up @@ -121,7 +116,7 @@ function DynamicPPL.predict(
varname_vals = mapreduce(
collect,
vcat,
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
Expand Down
239 changes: 0 additions & 239 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -837,245 +837,6 @@ end
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET)

"""
varname_leaves(vn::VarName, val)

Return an iterator over all varnames that are represented by `vn` on `val`.

# Examples
```jldoctest
julia> using DynamicPPL: varname_leaves

julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]

julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
```
"""
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val))
end
return Iterators.flatten(iter)
end

"""
varname_and_value_leaves(vn::VarName, val)

Return an iterator over all varname-value pairs that are represented by `vn` on `val`.

# Examples
```jldoctest varname-and-value-leaves
julia> using DynamicPPL: varname_and_value_leaves

julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)

julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```

There is also some special handling for certain types:

```jldoctest varname-and-value-leaves
julia> using LinearAlgebra

julia> x = reshape(1:4, 2, 2);

julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1, 1], 1)
(x[2, 1], 2)
(x[2, 2], 4)

julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1, 1], 1)
(x[1, 2], 3)
(x[2, 2], 4)

julia> # `Cholesky` with lower-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
(x.L[1, 1], 1.0)
(x.L[2, 1], 0.0)
(x.L[2, 2], 1.0)

julia> # `Cholesky` with upper-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
(x.U[1, 1], 1.0)
(x.U[1, 2], 0.0)
(x.U[2, 2], 1.0)
```
"""
function varname_and_value_leaves(vn::VarName, x)
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
end

"""
varname_and_value_leaves(container)

Return an iterator over all varname-value pairs that are represented by `container`.

This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
containing multiple varnames.

See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).

# Examples
```jldoctest varname-and-value-leaves-container
julia> using DynamicPPL: varname_and_value_leaves

julia> # With an `OrderedDict`
dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(dict))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)

julia> # With a `NamedTuple`
nt = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_and_value_leaves(nt))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)
```
"""
function varname_and_value_leaves(container::OrderedDict)
return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container)
end
function varname_and_value_leaves(container::NamedTuple)
return Iterators.flatten(
varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container)
)
end

"""
Leaf{T}

A container that represents the leaf of a nested structure, implementing
`iterate` to return itself.

This is particularly useful in conjunction with `Iterators.flatten` to
prevent flattening of nested structures.
"""
struct Leaf{T}
value::T
end

Leaf(xs...) = Leaf(xs)

# Allow us to treat `Leaf` as an iterator containing a single element.
# Something like an `[x]` would also be an iterator with a single element,
# but when we call `flatten` on this, it would also iterate over `x`,
# unflattening that too. By making `Leaf` a single-element iterator, which
# returns itself, we can call `iterate` on this as many times as we like
# without causing any change. The result is that `Iterators.flatten`
# will _not_ unflatten `Leaf`s.
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
#
# julia> iterate(1)
# (1, nothing)
#
# One immediate example where this becomes in our scenario is that we might
# have `missing` values in our data, which does _not_ have an `iterate`
# implemented. Calling `Iterators.flatten` on this would cause an error.
Base.iterate(leaf::Leaf) = leaf, nothing
Base.iterate(::Leaf, _) = nothing

# Convenience.
value(leaf::Leaf) = leaf.value

# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
return (
Leaf(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_and_value_leaves_inner(
VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
)
end

return Iterators.flatten(iter)
end
# Special types.
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
# TODO: Or do we use `PDMat` here?
return if x.uplo == 'L'
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L)
else
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U)
end
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
end

broadcast_safe(x) = x
broadcast_safe(x::Distribution) = (x,)
broadcast_safe(x::AbstractContext) = (x,)
Expand Down
2 changes: 1 addition & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

# Extract varnames and values.
vns_and_vals_xs = map(
collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
collect ∘ Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs
)
vns = map(first, first(vns_and_vals_xs))
vals = map(vns_and_vals_xs) do vns_and_vals
Expand Down
2 changes: 1 addition & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
# We have to use varname_and_value_leaves so that each parameter is a scalar
dicts = map(varinfos) do t
vals = DynamicPPL.values_as(t, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
tuples = mapreduce(collect, vcat, iters)
# The following loop is a replacement for:
# push!(varnames, map(first, tuples)...)
Expand Down
2 changes: 1 addition & 1 deletion test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ end
θ_new = var_info[:]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) ∉ keys(chain)
Expand Down
Loading