diff --git a/HISTORY.md b/HISTORY.md index e67ceba0a..0d2a56606 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.36.14 + +Added compatibility with AbstractPPL@0.12. + ## 0.36.13 Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method. diff --git a/Project.toml b/Project.toml index cd473354c..351fe6365 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.13" +version = "0.36.14" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11" +AbstractPPL = "0.11, 0.12" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/Project.toml b/docs/Project.toml index c00c29c96..3f258909a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] -AbstractPPL = "0.11" +AbstractPPL = "0.11, 0.12" Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f11b8a3ec..28bf488fa 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -481,7 +481,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed. """ istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) function istrans(vi::AbstractVarInfo, vns::AbstractVector) - return !isempty(vns) && all(Base.Fix1(istrans, vi), vns) + # This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`. + # In theory that should work perfectly fine. For unbeknownst reasons, + # Julia 1.10 fails to infer its return type correctly. Thus we use this + # slightly longer definition. + isempty(vns) && return false + for vn in vns + istrans(vi, vn) || return false + end + return true end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3ae425896..2297bc9e1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName # Attempt to split into `parent` and `child` optic. parent, child, issuccess = splitoptic(getoptic(vn)) do optic o = optic === nothing ? identity : optic - haskey(dict, VarName(vn, o)) + haskey(dict, VarName{getsym(vn)}(o)) end # When combined with `VarInfo`, `nothing` is equivalent to `identity`. keyoptic = parent === nothing ? identity : parent @@ -372,7 +372,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName BangBang.setindex!!(dict, val, vn) else # Split exists ⟹ trying to set an existing key. - vn_key = VarName(vn, keyoptic) + vn_key = VarName{getsym(vn)}(keyoptic) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end return Accessors.@set vi.values = dict_new diff --git a/src/utils.jl b/src/utils.jl index d828fd771..73a8b48b9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -597,7 +597,7 @@ x """ function parent(vn::VarName) p = parent(getoptic(vn)) - return p === nothing ? VarName(vn, identity) : VarName(vn, p) + return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p) end """ @@ -712,7 +712,7 @@ ERROR: Could not find x.a[2] in x.a[1] function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} _, child, issuccess = splitoptic(getoptic(vn_child)) do optic o = optic === nothing ? identity : optic - VarName(vn_child, o) == vn_parent + o == getoptic(vn_parent) end issuccess || error("Could not find $vn_parent in $vn_child") @@ -907,7 +907,7 @@ function hasvalue(vals::AbstractDict, vn::VarName) # If `issuccess` is `true`, we found such a split, and hence `vn` is present. parent, child, issuccess = splitoptic(getoptic(vn)) do optic o = optic === nothing ? identity : optic - haskey(vals, VarName(vn, o)) + haskey(vals, VarName{getsym(vn)}(o)) end # When combined with `VarInfo`, `nothing` is equivalent to `identity`. keyoptic = parent === nothing ? identity : parent @@ -916,7 +916,7 @@ function hasvalue(vals::AbstractDict, vn::VarName) issuccess || return false # At this point we just need to check that we `canview` the value. - value = vals[VarName(vn, keyoptic)] + value = vals[VarName{getsym(vn)}(keyoptic)] return canview(child, value) end @@ -936,7 +936,7 @@ function nested_getindex(values::AbstractDict, vn::VarName) # Split the optic into the key / `parent` and the extraction optic / `child`. parent, child, issuccess = splitoptic(getoptic(vn)) do optic o = optic === nothing ? identity : optic - haskey(values, VarName(vn, o)) + haskey(values, VarName{getsym(vn)}(o)) end # When combined with `VarInfo`, `nothing` is equivalent to `identity`. keyoptic = parent === nothing ? identity : parent @@ -949,7 +949,7 @@ function nested_getindex(values::AbstractDict, vn::VarName) # TODO: Should we also check that we `canview` the extracted `value` # rather than just let it fail upon `get` call? - value = values[VarName(vn, keyoptic)] + value = values[VarName{getsym(vn)}(keyoptic)] return child(value) end @@ -1067,20 +1067,21 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]) - for I in CartesianIndices(val) + varname_leaves( + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] + ) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do sym - optic = Accessors.PropertyLens{sym}() - varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val)) + iter = Iterators.map(keys(val)) do k + optic = Accessors.PropertyLens{k}() + varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) end return Iterators.flatten(iter) end @@ -1110,7 +1111,7 @@ julia> foreach(println, varname_and_value_leaves(@varname(x), x)) (x.z[2][1], 3.0) ``` -There are also some special handling for certain types: +There is also some special handling for certain types: ```jldoctest varname-and-value-leaves julia> using LinearAlgebra @@ -1229,7 +1230,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -1238,14 +1239,14 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) end -function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do sym - optic = DynamicPPL.Accessors.PropertyLens{sym}() +function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do k + optic = Accessors.PropertyLens{k}() varname_and_value_leaves_inner( VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) ) @@ -1264,20 +1265,14 @@ function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( - Leaf( - VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), - x[I], - ) + Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) # Iteration over the lower-triangular indices. for I in CartesianIndices(x) if I[1] >= I[2] ) end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( - Leaf( - VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), - x[I], - ) + Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) # Iteration over the upper-triangular indices. for I in CartesianIndices(x) if I[1] <= I[2] ) diff --git a/test/Project.toml b/test/Project.toml index 4f9ff4220..afecba1c4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11" +AbstractPPL = "0.11, 0.12" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1"