Skip to content

Commit c8e5841

Browse files
authored
use varname_leaves from AbstractPPL instead (#1030)
* use `varname_leaves` from AbstractPPL instead * add changelog entry * fix import
1 parent 2d18ce3 commit c8e5841

File tree

8 files changed

+10
-253
lines changed

8 files changed

+10
-253
lines changed

HISTORY.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## 0.38.0
44

5+
The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl.
6+
Their behaviour is otherwise identical.
7+
58
[...]
69

710
## 0.37.1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4747
[compat]
4848
ADTypes = "1"
4949
AbstractMCMC = "5"
50-
AbstractPPL = "0.13"
50+
AbstractPPL = "0.13.1"
5151
Accessors = "0.1"
5252
BangBang = "0.4.1"
5353
Bijectors = "0.13.18, 0.14, 0.15"

docs/src/api.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,6 @@ DynamicPPL.maybe_invlink_before_eval!!
435435
Base.merge(::AbstractVarInfo)
436436
DynamicPPL.subset
437437
DynamicPPL.unflatten
438-
DynamicPPL.varname_leaves
439-
DynamicPPL.varname_and_value_leaves
440438
```
441439

442440
### Evaluation Contexts

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
module DynamicPPLMCMCChainsExt
22

3-
if isdefined(Base, :get_extension)
4-
using DynamicPPL: DynamicPPL
5-
using MCMCChains: MCMCChains
6-
else
7-
using ..DynamicPPL: DynamicPPL
8-
using ..MCMCChains: MCMCChains
9-
end
3+
using DynamicPPL: DynamicPPL, AbstractPPL
4+
using MCMCChains: MCMCChains
105

116
# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
127
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
@@ -121,7 +116,7 @@ function DynamicPPL.predict(
121116
varname_vals = mapreduce(
122117
collect,
123118
vcat,
124-
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
119+
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
125120
)
126121

127122
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))

src/utils.jl

Lines changed: 0 additions & 239 deletions
Original file line numberDiff line numberDiff line change
@@ -837,245 +837,6 @@ end
837837
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
838838
infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET)
839839

840-
"""
841-
varname_leaves(vn::VarName, val)
842-
843-
Return an iterator over all varnames that are represented by `vn` on `val`.
844-
845-
# Examples
846-
```jldoctest
847-
julia> using DynamicPPL: varname_leaves
848-
849-
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
850-
x[1]
851-
x[2]
852-
853-
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
854-
x[1:2][1]
855-
x[1:2][2]
856-
857-
julia> x = (y = 1, z = [[2.0], [3.0]]);
858-
859-
julia> foreach(println, varname_leaves(@varname(x), x))
860-
x.y
861-
x.z[1][1]
862-
x.z[2][1]
863-
```
864-
"""
865-
varname_leaves(vn::VarName, ::Real) = [vn]
866-
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
867-
return (
868-
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
869-
I in CartesianIndices(val)
870-
)
871-
end
872-
function varname_leaves(vn::VarName, val::AbstractArray)
873-
return Iterators.flatten(
874-
varname_leaves(
875-
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I]
876-
) for I in CartesianIndices(val)
877-
)
878-
end
879-
function varname_leaves(vn::VarName, val::NamedTuple)
880-
iter = Iterators.map(keys(val)) do k
881-
optic = Accessors.PropertyLens{k}()
882-
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), optic(val))
883-
end
884-
return Iterators.flatten(iter)
885-
end
886-
887-
"""
888-
varname_and_value_leaves(vn::VarName, val)
889-
890-
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
891-
892-
# Examples
893-
```jldoctest varname-and-value-leaves
894-
julia> using DynamicPPL: varname_and_value_leaves
895-
896-
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
897-
(x[1], 1)
898-
(x[2], 2)
899-
900-
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
901-
(x[1:2][1], 1)
902-
(x[1:2][2], 2)
903-
904-
julia> x = (y = 1, z = [[2.0], [3.0]]);
905-
906-
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
907-
(x.y, 1)
908-
(x.z[1][1], 2.0)
909-
(x.z[2][1], 3.0)
910-
```
911-
912-
There is also some special handling for certain types:
913-
914-
```jldoctest varname-and-value-leaves
915-
julia> using LinearAlgebra
916-
917-
julia> x = reshape(1:4, 2, 2);
918-
919-
julia> # `LowerTriangular`
920-
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
921-
(x[1, 1], 1)
922-
(x[2, 1], 2)
923-
(x[2, 2], 4)
924-
925-
julia> # `UpperTriangular`
926-
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
927-
(x[1, 1], 1)
928-
(x[1, 2], 3)
929-
(x[2, 2], 4)
930-
931-
julia> # `Cholesky` with lower-triangular
932-
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
933-
(x.L[1, 1], 1.0)
934-
(x.L[2, 1], 0.0)
935-
(x.L[2, 2], 1.0)
936-
937-
julia> # `Cholesky` with upper-triangular
938-
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
939-
(x.U[1, 1], 1.0)
940-
(x.U[1, 2], 0.0)
941-
(x.U[2, 2], 1.0)
942-
```
943-
"""
944-
function varname_and_value_leaves(vn::VarName, x)
945-
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
946-
end
947-
948-
"""
949-
varname_and_value_leaves(container)
950-
951-
Return an iterator over all varname-value pairs that are represented by `container`.
952-
953-
This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
954-
containing multiple varnames.
955-
956-
See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).
957-
958-
# Examples
959-
```jldoctest varname-and-value-leaves-container
960-
julia> using DynamicPPL: varname_and_value_leaves
961-
962-
julia> # With an `OrderedDict`
963-
dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
964-
965-
julia> foreach(println, varname_and_value_leaves(dict))
966-
(y, 1)
967-
(z[1][1], 2.0)
968-
(z[2][1], 3.0)
969-
970-
julia> # With a `NamedTuple`
971-
nt = (y = 1, z = [[2.0], [3.0]]);
972-
973-
julia> foreach(println, varname_and_value_leaves(nt))
974-
(y, 1)
975-
(z[1][1], 2.0)
976-
(z[2][1], 3.0)
977-
```
978-
"""
979-
function varname_and_value_leaves(container::OrderedDict)
980-
return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container)
981-
end
982-
function varname_and_value_leaves(container::NamedTuple)
983-
return Iterators.flatten(
984-
varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container)
985-
)
986-
end
987-
988-
"""
989-
Leaf{T}
990-
991-
A container that represents the leaf of a nested structure, implementing
992-
`iterate` to return itself.
993-
994-
This is particularly useful in conjunction with `Iterators.flatten` to
995-
prevent flattening of nested structures.
996-
"""
997-
struct Leaf{T}
998-
value::T
999-
end
1000-
1001-
Leaf(xs...) = Leaf(xs)
1002-
1003-
# Allow us to treat `Leaf` as an iterator containing a single element.
1004-
# Something like an `[x]` would also be an iterator with a single element,
1005-
# but when we call `flatten` on this, it would also iterate over `x`,
1006-
# unflattening that too. By making `Leaf` a single-element iterator, which
1007-
# returns itself, we can call `iterate` on this as many times as we like
1008-
# without causing any change. The result is that `Iterators.flatten`
1009-
# will _not_ unflatten `Leaf`s.
1010-
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
1011-
#
1012-
# julia> iterate(1)
1013-
# (1, nothing)
1014-
#
1015-
# One immediate example where this becomes in our scenario is that we might
1016-
# have `missing` values in our data, which does _not_ have an `iterate`
1017-
# implemented. Calling `Iterators.flatten` on this would cause an error.
1018-
Base.iterate(leaf::Leaf) = leaf, nothing
1019-
Base.iterate(::Leaf, _) = nothing
1020-
1021-
# Convenience.
1022-
value(leaf::Leaf) = leaf.value
1023-
1024-
# Leaf-types.
1025-
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
1026-
function varname_and_value_leaves_inner(
1027-
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
1028-
)
1029-
return (
1030-
Leaf(
1031-
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
1032-
val[I],
1033-
) for I in CartesianIndices(val)
1034-
)
1035-
end
1036-
# Containers.
1037-
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
1038-
return Iterators.flatten(
1039-
varname_and_value_leaves_inner(
1040-
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) AbstractPPL.getoptic(vn)),
1041-
val[I],
1042-
) for I in CartesianIndices(val)
1043-
)
1044-
end
1045-
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
1046-
iter = Iterators.map(keys(val)) do k
1047-
optic = Accessors.PropertyLens{k}()
1048-
varname_and_value_leaves_inner(
1049-
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
1050-
)
1051-
end
1052-
1053-
return Iterators.flatten(iter)
1054-
end
1055-
# Special types.
1056-
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
1057-
# TODO: Or do we use `PDMat` here?
1058-
return if x.uplo == 'L'
1059-
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
1060-
else
1061-
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() vn, x.U)
1062-
end
1063-
end
1064-
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
1065-
return (
1066-
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
1067-
# Iteration over the lower-triangular indices.
1068-
for I in CartesianIndices(x) if I[1] >= I[2]
1069-
)
1070-
end
1071-
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
1072-
return (
1073-
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn)), x[I])
1074-
# Iteration over the upper-triangular indices.
1075-
for I in CartesianIndices(x) if I[1] <= I[2]
1076-
)
1077-
end
1078-
1079840
broadcast_safe(x) = x
1080841
broadcast_safe(x::Distribution) = (x,)
1081842
broadcast_safe(x::AbstractContext) = (x,)

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
347347

348348
# Extract varnames and values.
349349
vns_and_vals_xs = map(
350-
collect Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
350+
collect Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs
351351
)
352352
vns = map(first, first(vns_and_vals_xs))
353353
vals = map(vns_and_vals_xs) do vns_and_vals

test/test_util.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
7272
# We have to use varname_and_value_leaves so that each parameter is a scalar
7373
dicts = map(varinfos) do t
7474
vals = DynamicPPL.values_as(t, OrderedDict)
75-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
75+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
7676
tuples = mapreduce(collect, vcat, iters)
7777
# The following loop is a replacement for:
7878
# push!(varnames, map(first, tuples)...)

test/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ end
488488
θ_new = var_info[:]
489489
@test θ_old != θ_new
490490
vals = DynamicPPL.values_as(var_info, OrderedDict)
491-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
491+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
492492
for (n, v) in mapreduce(collect, vcat, iters)
493493
n = string(n)
494494
if Symbol(n) keys(chain)

0 commit comments

Comments
 (0)