Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.14

Added compatibility with [email protected].

## 0.36.13

Added documentation for the `returned(::Model, ::MCMCChains.Chains)` method.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.36.13"
version = "0.36.14"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
Expand Down
10 changes: 9 additions & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,15 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
"""
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
return !isempty(vns) && all(Base.Fix1(istrans, vi), vns)
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
# In theory that should work perfectly fine. For unbeknownst reasons,
# Julia 1.10 fails to infer its return type correctly. Thus we use this
# slightly longer definition.
isempty(vns) && return false
for vn in vns
istrans(vi, vn) || return false
end
return true
end
Comment on lines -484 to 493
Copy link
Member Author

@penelopeysm penelopeysm Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See main PR comment for rationale (if interested). This is the only meaningful change in this PR, the rest is quite mundane


"""
Expand Down
8 changes: 5 additions & 3 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,15 @@ function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarN
return vi
end

function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName)
function BangBang.setindex!!(
vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName{sym}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that VarName{sym} where {sym} function signatures might change performance, because they force specialisation on that argument. If they would, and if they would make it worse, this could always be circumvented with a getsymbol(vn::VarName{sym}) = sym function (probably that exists already).

I tried to compare these two benchmark results to see if there is an effect: #970 (comment) and #966 (comment). It does look like everything has gotten a bit slower, but this could easily be a GHA benchmarking fluctuation. If it's not too much trouble, would you be happy to check the benchmarks locally in a more controlled environment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's a really good catch. I think this is the only function which might be on a performance-sensitive path, but I'll change them all back anyway. The old definition of VarName(vn, optic) was VarName{getsym(vn)}(optic) so sticking to that should preserve exactly the same behaviour as before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had looked at the benchmarks as I was worried about the type stability thing -- isn't this PR faster on most of them (though I imagine very much within error?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverting this fixed the original istrans type instability, but now there's another one. This might take a while.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the links above this PR is a bit slower on most. I wouldn't put any serious stock into those benchmarks without trying them in a more stable environment first though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also be happy with a conclusion that the VarName{sym} where {sym}signatures have no (noticable) impact on benchmarks and this can be merged as-is.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, to my relief, it wasn't a different type stability problem, it was the same one. So the same fix (manually expanding all) still works.

Benchmarked smorgasboard:

using DynamicPPL, Distributions, BenchmarkTools
@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
    @assert length(x) == length(y)
    m ~ truncated(Normal(); lower=0)
    means ~ product_distribution(fill(Exponential(m), length(x)))
    stds = TV(undef, length(x))
    stds .~ Gamma(1, 1)
    for i in 1:length(x)
        x[i] ~ Normal(means[i], stds[i])
    end
    y ~ product_distribution(map((mean, std) -> Normal(mean, std), means, stds))
    0.0 ~ Normal(sum(y), 1)
    return (; m=m, means=means, stds=stds)
end
m = smorgasbord(randn(100), randn(100))
vi = VarInfo(m)
ctx_def = DefaultContext()
ctx_spl = SamplingContext()

Main:

julia> @benchmark DynamicPPL.evaluate!!($m, $vi, $ctx_def)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min  max):  11.250 μs  52.083 μs  ┊ GC (min  max): 0.00%  0.00%
 Time  (median):     12.292 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.525 μs ±  1.371 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▂  ▃▃▃▃▄▄█▂▃▁  ▁
  ▃▇█▆▄███████████▇█▆▅▄▃▂▄▄▃▃▃▂▂▂▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  11.2 μs         Histogram: frequency by time        16.6 μs <

 Memory estimate: 16.94 KiB, allocs estimate: 321.

julia> @benchmark DynamicPPL.evaluate!!($m, $vi, $ctx_spl)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min  max):  20.416 μs  129.583 μs  ┊ GC (min  max): 0.00%  0.00%
 Time  (median):     23.208 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   23.595 μs ±   2.289 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

             ▂▃▆▆█▅▃▃▂
  ▁▁▂▂▃▂▃▄▅▇████████████▆▆▅▅▄▃▃▃▃▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  20.4 μs         Histogram: frequency by time         30.6 μs <

 Memory estimate: 17.75 KiB, allocs estimate: 333.

This PR:

julia> @benchmark DynamicPPL.evaluate!!($m, $vi, $ctx_def)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min  max):  11.375 μs   3.182 ms  ┊ GC (min  max): 0.00%  98.77%
 Time  (median):     12.083 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.548 μs ± 31.706 μs  ┊ GC (mean ± σ):  2.50% ±  0.99%

    ▂ ▇█ █▇▇ ▆▄ ▃▁
  ▂▅███████████▆███▄▆▅▃▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▂▂▂▂▂▂▂▂ ▄
  11.4 μs         Histogram: frequency by time        15.6 μs <

 Memory estimate: 16.94 KiB, allocs estimate: 321.

julia> @benchmark DynamicPPL.evaluate!!($m, $vi, $ctx_spl)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min  max):  21.333 μs   7.169 ms  ┊ GC (min  max): 0.00%  99.12%
 Time  (median):     24.083 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.260 μs ± 71.474 μs  ┊ GC (mean ± σ):  2.81% ±  0.99%

             ▁▇▅█▄▆▃▆▁▃▁▂
  ▂▂▁▂▁▂▂▂▃▃▇████████████▇█▅▇▆▇▅▆▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▂▂▂ ▄
  21.3 μs         Histogram: frequency by time        29.9 μs <

 Memory estimate: 17.75 KiB, allocs estimate: 333.

) where {sym}
# For dictlike objects, we treat the entire `vn` as a _key_ to set.
dict = values_as(vi)
# Attempt to split into `parent` and `child` optic.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(dict, VarName(vn, o))
haskey(dict, VarName{sym}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -372,7 +374,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
BangBang.setindex!!(dict, val, vn)
else
# Split exists ⟹ trying to set an existing key.
vn_key = VarName(vn, keyoptic)
vn_key = VarName{sym}(keyoptic)
BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
end
return Accessors.@set vi.values = dict_new
Expand Down
76 changes: 36 additions & 40 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,9 @@
x
```
"""
function parent(vn::VarName)
function parent(vn::VarName{sym}) where {sym}

Check warning on line 598 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L598

Added line #L598 was not covered by tests
p = parent(getoptic(vn))
return p === nothing ? VarName(vn, identity) : VarName(vn, p)
return p === nothing ? VarName{sym}(identity) : VarName{sym}(p)

Check warning on line 600 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L600

Added line #L600 was not covered by tests
end

"""
Expand Down Expand Up @@ -712,7 +712,7 @@
function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
_, child, issuccess = splitoptic(getoptic(vn_child)) do optic
o = optic === nothing ? identity : optic
VarName(vn_child, o) == vn_parent
o == getoptic(vn_parent)
end

issuccess || error("Could not find $vn_parent in $vn_child")
Expand Down Expand Up @@ -898,7 +898,7 @@

# For `dictlike` we need to check wether `vn` is "immediately" present, or
# if some ancestor of `vn` is present in `dictlike`.
function hasvalue(vals::AbstractDict, vn::VarName)
function hasvalue(vals::AbstractDict, vn::VarName{sym}) where {sym}
# First we check if `vn` is present as is.
haskey(vals, vn) && return true

Expand All @@ -907,7 +907,7 @@
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(vals, VarName(vn, o))
haskey(vals, VarName{sym}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -916,7 +916,7 @@
issuccess || return false

# At this point we just need to check that we `canview` the value.
value = vals[VarName(vn, keyoptic)]
value = vals[VarName{sym}(keyoptic)]

return canview(child, value)
end
Expand All @@ -927,7 +927,7 @@
Return value corresponding to `vn` in `values` by also looking
in the the actual values of the dict.
"""
function nested_getindex(values::AbstractDict, vn::VarName)
function nested_getindex(values::AbstractDict, vn::VarName{sym}) where {sym}
maybeval = get(values, vn, nothing)
if maybeval !== nothing
return maybeval
Expand All @@ -936,7 +936,7 @@
# Split the optic into the key / `parent` and the extraction optic / `child`.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(values, VarName(vn, o))
haskey(values, VarName{sym}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent
Expand All @@ -949,7 +949,7 @@

# TODO: Should we also check that we `canview` the extracted `value`
# rather than just let it fail upon `get` call?
value = values[VarName(vn, keyoptic)]
value = values[VarName{sym}(keyoptic)]
return child(value)
end

Expand Down Expand Up @@ -1065,22 +1065,24 @@
```
"""
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
function varname_leaves(
vn::VarName{sym}, val::AbstractArray{<:Union{Real,Missing}}
) where {sym}
return (
VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
function varname_leaves(vn::VarName{sym}, val::AbstractArray) where {sym}

Check warning on line 1076 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1076

Added line #L1076 was not covered by tests
return Iterators.flatten(
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I])
varname_leaves(VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I])
for I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = Accessors.PropertyLens{sym}()
varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
function varname_leaves(vn::VarName{sym}, val::NamedTuple) where {sym}
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_leaves(VarName{sym}(optic ∘ getoptic(vn)), optic(val))

Check warning on line 1085 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1082-L1085

Added lines #L1082 - L1085 were not covered by tests
end
return Iterators.flatten(iter)
end
Expand Down Expand Up @@ -1225,30 +1227,26 @@
# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
vn::VarName{sym}, val::AbstractArray{<:Union{Real,Missing}}
) where {sym}
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
val[I],
VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
function varname_and_value_leaves_inner(vn::VarName{sym}, val::AbstractArray) where {sym}

Check warning on line 1239 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1239

Added line #L1239 was not covered by tests
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
val[I],
VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = DynamicPPL.Accessors.PropertyLens{sym}()
varname_and_value_leaves_inner(
VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
)
function varname_and_value_leaves_inner(vn::VarName{sym}, val::NamedTuple) where {sym}
iter = Iterators.map(keys(val)) do k
optic = Accessors.PropertyLens{k}()
varname_and_value_leaves_inner(VarName{sym}(optic ∘ getoptic(vn)), optic(val))

Check warning on line 1249 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1246-L1249

Added lines #L1246 - L1249 were not covered by tests
end

return Iterators.flatten(iter)
Expand All @@ -1262,22 +1260,20 @@
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U)
end
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
function varname_and_value_leaves_inner(
vn::VarName{sym}, x::LinearAlgebra.LowerTriangular
) where {sym}
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
x[I],
)
Leaf(VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), x[I])
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
function varname_and_value_leaves_inner(

Check warning on line 1272 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L1272

Added line #L1272 was not covered by tests
vn::VarName{sym}, x::LinearAlgebra.UpperTriangular
) where {sym}
return (
Leaf(
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)),
x[I],
)
Leaf(VarName{sym}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), x[I])
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11"
AbstractPPL = "0.11, 0.12"
Accessors = "0.1"
Aqua = "0.8"
Bijectors = "0.15.1"
Expand Down