Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.14

Added compatibility with [email protected].

## 0.36.13

Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.36.13"
version = "0.36.14"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
Expand Down
10 changes: 9 additions & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
"""
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
return !isempty(vns) && all(Base.Fix1(istrans, vi), vns)
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
# In theory that should work perfectly fine. For unbeknownst reasons,
# Julia 1.10 fails to infer its return type correctly. Thus we use this
# slightly longer definition.
isempty(vns) && return false
for vn in vns
istrans(vi, vn) || return false
end
return true
end
Comment on lines -484 to 493
Copy link
Member Author

@penelopeysm penelopeysm Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See main PR comment for rationale (if interested). This is the only meaningful change in this PR, the rest is quite mundane


"""
Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
# Attempt to split into `parent` and `child` optic.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(dict, VarName(vn, o))
haskey(dict, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -372,7 +372,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
BangBang.setindex!!(dict, val, vn)
else
# Split exists ⟹ trying to set an existing key.
vn_key = VarName(vn, keyoptic)
vn_key = VarName{getsym(vn)}(keyoptic)
BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
end
return Accessors.@set vi.values = dict_new
Expand Down
47 changes: 21 additions & 26 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@
"""
function parent(vn::VarName)
p = parent(getoptic(vn))
return p === nothing ? VarName(vn, identity) : VarName(vn, p)
return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p)

Check warning on line 600 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L600

Added line #L600 was not covered by tests
end

"""
Expand Down Expand Up @@ -712,7 +712,7 @@
function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
_, child, issuccess = splitoptic(getoptic(vn_child)) do optic
o = optic === nothing ? identity : optic
VarName(vn_child, o) == vn_parent
o == getoptic(vn_parent)
end

issuccess || error("Could not find $vn_parent in $vn_child")
Expand Down Expand Up @@ -907,7 +907,7 @@
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(vals, VarName(vn, o))
haskey(vals, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -916,7 +916,7 @@
issuccess || return false

# At this point we just need to check that we `canview` the value.
value = vals[VarName(vn, keyoptic)]
value = vals[VarName{getsym(vn)}(keyoptic)]

return canview(child, value)
end
Expand All @@ -936,7 +936,7 @@
# Split the optic into the key / `parent` and the extraction optic / `child`.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(values, VarName(vn, o))
haskey(values, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -949,7 +949,7 @@

# TODO: Should we also check that we `canview` the extracted `value`
# rather than just let it fail upon `get` call?
value = values[VarName(vn, keyoptic)]
value = values[VarName{getsym(vn)}(keyoptic)]
return child(value)
end

Expand Down Expand Up @@ -1067,20 +1067,21 @@
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I])
for I in CartesianIndices(val)
varname_leaves(
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = Accessors.PropertyLens{sym}()
varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val))

Check warning on line 1084 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1082-L1084

Added lines #L1082 - L1084 were not covered by tests
end
return Iterators.flatten(iter)
end
Expand Down Expand Up @@ -1110,7 +1111,7 @@
(x.z[2][1], 3.0)
```

There are also some special handling for certain types:
There is also some special handling for certain types:

```jldoctest varname-and-value-leaves
julia> using LinearAlgebra
Expand Down Expand Up @@ -1229,7 +1230,7 @@
)
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
Expand All @@ -1238,14 +1239,14 @@
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = DynamicPPL.Accessors.PropertyLens{sym}()
function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()

Check warning on line 1249 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1247-L1249

Added lines #L1247 - L1249 were not covered by tests
varname_and_value_leaves_inner(
VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
)
Expand All @@ -1264,20 +1265,14 @@
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
x[I],
)
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
x[I],
)
Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I])
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
Aqua = "0.8"
Bijectors = "0.15.1"
Expand Down