Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
34ad663
Fix and improve map!! and apply!!
mhauru Dec 18, 2025
dc6291d
mapreduce and nested PartialArrays
mhauru Dec 18, 2025
20ed575
Test invariants more
mhauru Dec 18, 2025
477b715
Work-in-progress VNTVarInfo
mhauru Dec 18, 2025
7aa6013
Fix a bug in link
mhauru Dec 18, 2025
9ae56ab
Merge branch 'mhauru/vnt-for-vaimacc' into mhauru/vnt-for-varinfo
mhauru Dec 19, 2025
bdeeb4a
Update map!! to operate on pairs
mhauru Jan 8, 2026
5498d82
Split map!! into map_pairs!! and map_values!!, fix some bugs
mhauru Jan 8, 2026
81be716
Make mapreduce operate on pairs
mhauru Jan 8, 2026
37f4adf
Implement keys and values using mapreduce
mhauru Jan 8, 2026
fc29cc6
Add more VNT constructors
mhauru Jan 9, 2026
c6d0677
Add VNT subset
mhauru Jan 9, 2026
c18258c
Make _compose_no_identity handle typed_identity too
mhauru Jan 9, 2026
b91e6ff
Myriad improvements to VNTVarInfo, overhaul varinfo.jl tests to use V…
mhauru Jan 9, 2026
4b1a8f5
Merge remote-tracking branch 'origin/mhauru/vnt-for-fastldf' into mha…
mhauru Jan 9, 2026
8018f45
Fix a couple of ArrayLikeBlock bugs
mhauru Jan 9, 2026
1cbcda7
Fix PartialArray map bug
mhauru Jan 9, 2026
573cd5a
In VNTVarInfo, handle variables with varying dimensions correctly
mhauru Jan 9, 2026
c353cbc
Fix two small bugs
mhauru Jan 9, 2026
a36bb15
Allow nested PartialArrays with ArrayLikeBlocks
mhauru Jan 9, 2026
bf05554
Stop testing for NamedDist with unconcrete VarName
mhauru Jan 9, 2026
7857eae
Misc bugfixes
mhauru Jan 9, 2026
16fe150
Stop running SVI and VNT tests
mhauru Jan 9, 2026
51a518f
Fix LDF bug
mhauru Jan 12, 2026
1950a93
Fix some bugs, simplify (inv)linking
mhauru Jan 12, 2026
051521a
Fix some tests
mhauru Jan 12, 2026
9812ad0
Comment back in include of old VI files
mhauru Jan 12, 2026
6d44954
Remote JET extension and experimental.jl
mhauru Jan 12, 2026
d5bfa2c
Reimplement bijector.jl
mhauru Jan 12, 2026
eb903e1
Move linking code to VarInfo, fix ProductNamedDistribution bijector, etc
mhauru Jan 12, 2026
469a715
Mark a test as broken
mhauru Jan 12, 2026
89a8396
Various bugfixes
mhauru Jan 12, 2026
8cf8ab0
Remove SimpleVarInfo, VarNamedVector, and the old VarInfo type
mhauru Jan 12, 2026
8ba36f6
Fix a lot of doctests
mhauru Jan 13, 2026
1f6335d
Rename vntvarinfo.jl to varinfo.jl
mhauru Jan 13, 2026
dbcf5f6
Rename VNTVarInfo to VarInfo
mhauru Jan 13, 2026
0edaa53
Remove (un)typed_varinfo
mhauru Jan 13, 2026
c2748a7
Add docstrings to varinfo.jl
mhauru Jan 13, 2026
6dbae23
Simplify transformations
mhauru Jan 13, 2026
2fa7333
Fix docs
mhauru Jan 13, 2026
7cbc4a7
Mark some inference tests as broken on 1.10
mhauru Jan 13, 2026
b4361c0
Polish VNT and tests
mhauru Jan 13, 2026
73e50df
Fix broken test marking
mhauru Jan 13, 2026
922fbb6
Polish varinfo.jl
mhauru Jan 13, 2026
66c7970
Polish internal docs
mhauru Jan 13, 2026
51fdcbe
More broken inference tests on v1.10
mhauru Jan 13, 2026
07a13c4
Export VarNamedTuple and its functions
mhauru Jan 14, 2026
92dd490
Add HISTROY.md entry on the new VarInfo
mhauru Jan 14, 2026
06f6c1e
Apply suggestions from code review
mhauru Jan 14, 2026
4f893bc
Use SkipSizeCheck rather than Val(:pass)
mhauru Jan 14, 2026
fdb1373
Remove getindex with dist argument
mhauru Jan 14, 2026
a023a7f
Simplify map and mapreduce for VNT
mhauru Jan 14, 2026
c369b09
Remove unused utility functions
mhauru Jan 15, 2026
6128a56
Use OnlyAccsVarInfo in extract_priors
mhauru Jan 15, 2026
8fddfef
Make linking status a type parameter of VarInfo
mhauru Jan 15, 2026
aa3adb3
Fix a typo
mhauru Jan 15, 2026
0c03233
Simplify code
mhauru Jan 15, 2026
39df57b
Fix comments, remove dead line
mhauru Jan 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,68 @@

## 0.40

### `VarNamedTuple`

DynamicPPL now exports a new type, called `VarNamedTuple`, which stores values keyed by `VarName`s.
With it are exported a few new functions for using it: `map_values!!`, `map_pairs!!`, `apply!!`.
Our documentation's Internals section now has a page about `VarNamedTuple`, how it works, and what it's good for.

`VarNamedTuple` is now used internally in many different parts: In `VarInfo`, in `values_as_in_model`, in `LogDensityFunction`, etc.
Almost all of the below changes are the consequence from switching over to using `VarNamedTuple` for various features internally.

### Overhaul of `VarInfo`

DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types.
Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped", and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends.
These have all been replaced by a rewritten implementation of `VarInfo`.
While the basics of the `VarInfo` interface remain the same, this brings with it many changes:

#### No more many `AbstractVarInfo` types

`SimpleVarInfo`, `untyped_varinfo`, `typed_varinfo`, and many other constructors, some exported some not, have been removed.
The remaining one is `VarInfo(...)`, which can take a model or a collection of values.
See the docstring for details.

Some related types and functions, that weren't exported but may have been used by some, have also been removed, most notably `VarNamedVector` and its associated functions like `loosen_types!!` and `tighten_types!!`.

#### Setting and getting values

Previously the various `AbstractVarInfo` types had a multitude of functions for setting values:
`push!!`, `push!`, `setindex!`, `update!`, `update_internal!`, `insert_internal!`, `reset!`, etc.
These have all been replaced by three functions

- `setindex!!` is the one to use for simply setting a variable in `VarInfo` to a known value. It works regardless of whether the variable already exists.
- `setindex_internal!!` is the one to use for setting the internal, vectorised representation of a variable. See the docstring for details.
- `setindex_with_dist!!` is to be used when you want to set a value, but choose the internal representation based on which distribution this value is a sample for.

The order of the arguments for some of these functions has also changed, and now more closely matches the usual convention for `setindex!!`.

Note that `setindex!` (with a single `!`) is not defined, and thus you can't do `varinfo[varname] = new_value`.

`unflatten` works as before, but has been renamed to `unflatten!!`, since it may mutate the first argument and aliases memory with the second argument (it uses views rather than copies of the input vector).

#### Linking is now safer

`link!!` and `invlink!!` on `VarInfo` used to assume that the prior distribution of a variable didn't change from one execution to another (as it does in e.g. `truncated(dist; lower=x)` where `x` is a random variable).
This is no longer the case.
Linking should thus be safer to do.
The cost to pay is that calls to `link!!` and `invlink!!` (and the non-mutating versions) now trigger a model evaluation, to determine the correct priors to use.

#### Other miscellanea

- The `Experimental` module had functions like `Experimental.determine_suitable_varinfo` for determining which `AbstractVarInfo` type was suitable for a given model. This is now redundant and has been removed.
- `Bijectors.bijector(::Model)`, which creates a bijector from the vectorised variable space of the model to the linked variable space of the model, now has slightly different optional arguments. See the docstring for details.
- `NamedDist` no longer allows variable names with `Colon`s in them, such as `x[:]`.

There are probably also changes to the `VarInfo` interface that we've neglected to document here, since the overhaul of `VarInfo` has been quite complete.
If anything related to `VarInfo` is behaving unexpectedly, e.g. the arguments or return type of a function seem to have changed, please check the docstring, which should be comprehensive.

#### Performance benefits

The purpose of this overhaul of `VarInfo` is code simplification and performance benefits.

TODO(mhauru) Add some basic summary of what has gotten faster by how much.

### Changes to indexing random variables with square brackets

0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. ones that use indexing with square brackets.
Expand Down
3 changes: 0 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Expand All @@ -40,7 +39,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMooncakeExt = ["Mooncake"]
Expand All @@ -62,7 +60,6 @@ DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12, 1"
InteractiveUtils = "1"
JET = "0.9, 0.10, 0.11"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
Expand Down
61 changes: 25 additions & 36 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ using StableRNGs: StableRNG

rng = StableRNG(23)

colnames = [
"Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"
]
colnames = ["Model", "Dim", "AD Backend", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"]
function print_results(results_table; to_json=false)
if to_json
# Print to the given file as JSON
Expand Down Expand Up @@ -58,31 +56,26 @@ function run(; to_json=false)
end

# Specify the combinations to test:
# (Model Name, model instance, VarInfo choice, AD backend, linked)
# (Model Name, model instance, AD backend, linked)
chosen_combinations = [
(
"Simple assume observe",
Models.simple_assume_observe(randn(rng)),
:typed,
:forwarddiff,
false,
),
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
("Dynamic", Models.dynamic(), :typed, :mooncake, true),
("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
("LDA", lda_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :forwarddiff, false),
("Smorgasbord", smorgasbord_instance, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :mooncake, true),
("Smorgasbord", smorgasbord_instance, :enzyme, true),
("Loop univariate 1k", loop_univariate1k, :mooncake, true),
("Multivariate 1k", multivariate1k, :mooncake, true),
("Loop univariate 10k", loop_univariate10k, :mooncake, true),
("Multivariate 10k", multivariate10k, :mooncake, true),
("Dynamic", Models.dynamic(), :mooncake, true),
("Submodel", Models.parent(randn(rng)), :mooncake, true),
("LDA", lda_instance, :reversediff, true),
]

# Time running a model-like function that does not use DynamicPPL, as a reference point.
Expand All @@ -94,13 +87,13 @@ function run(; to_json=false)
@info "Reference evaluation time: $(reference_time) seconds"

results_table = Tuple{
String,Int,String,String,Bool,Union{Float64,Missing},Union{Float64,Missing}
String,Int,String,Bool,Union{Float64,Missing},Union{Float64,Missing}
}[]

for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
@info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked"
for (model_name, model, adbackend, islinked) in chosen_combinations
@info "Running benchmark for $model_name, $adbackend, $islinked"
relative_eval_time, relative_ad_eval_time = try
results = benchmark(model, varinfo_choice, adbackend, islinked)
results = benchmark(model, adbackend, islinked)
@info " t(eval) = $(results.primal_time)"
@info " t(grad) = $(results.grad_time)"
(results.primal_time / reference_time),
Expand All @@ -115,7 +108,6 @@ function run(; to_json=false)
model_name,
model_dimension(model, islinked),
string(adbackend),
string(varinfo_choice),
islinked,
relative_eval_time,
relative_ad_eval_time,
Expand All @@ -131,9 +123,8 @@ struct TestCase
model_name::String
dim::Integer
ad_backend::String
varinfo::String
linked::Bool
TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:5])...)
TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:4])...)
end
function combine(head_filename::String, base_filename::String)
head_results = try
Expand All @@ -148,23 +139,22 @@ function combine(head_filename::String, base_filename::String)
Dict{String,Any}[]
end
@info "Loaded $(length(base_results)) results from $base_filename"
# Identify unique combinations of (Model, Dim, AD Backend, VarInfo, Linked)
# Identify unique combinations of (Model, Dim, AD Backend, Linked)
head_testcases = Dict(
TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in head_results
TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in head_results
)
base_testcases = Dict(
TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in base_results
TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in base_results
)
all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases)))
@info "$(length(all_testcases)) unique test cases found"
sorted_testcases = sort(
collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend))
collect(all_testcases); by=(c -> (c.model_name, c.linked, c.ad_backend))
)
results_table = Tuple{
String,
Int,
String,
String,
Bool,
String,
String,
Expand All @@ -179,12 +169,12 @@ function combine(head_filename::String, base_filename::String)
sublabels = ["base", "this PR", "speedup"]
results_colnames = [
[
EmptyCells(5),
EmptyCells(4),
MultiColumn(3, "t(eval) / t(ref)"),
MultiColumn(3, "t(grad) / t(eval)"),
MultiColumn(3, "t(grad) / t(ref)"),
],
[colnames[1:5]..., sublabels..., sublabels..., sublabels...],
[colnames[1:4]..., sublabels..., sublabels..., sublabels...],
]
sprint_float(x::Float64) = @sprintf("%.2f", x)
sprint_float(m::Missing) = "err"
Expand All @@ -211,7 +201,6 @@ function combine(head_filename::String, base_filename::String)
c.model_name,
c.dim,
c.ad_backend,
c.varinfo,
c.linked,
sprint_float(base_eval),
sprint_float(head_eval),
Expand Down
43 changes: 7 additions & 36 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DynamicPPLBenchmarks

using DynamicPPL: VarInfo, SimpleVarInfo, VarName
using DynamicPPL: VarInfo, VarName
using DynamicPPL: DynamicPPL
using DynamicPPL.TestUtils.AD: run_ad, NoTest
using ADTypes: ADTypes
Expand All @@ -23,7 +23,7 @@ Return the dimension of `model`, accounting for linking, if any.
"""
function model_dimension(model, islinked)
vi = VarInfo()
model(StableRNG(23), vi)
vi = last(DynamicPPL.init!!(StableRNG(23), model, vi))
if islinked
vi = DynamicPPL.link(vi, model)
end
Expand Down Expand Up @@ -52,53 +52,24 @@ function to_backend(x::Union{AbstractString,Symbol})
end

"""
benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
benchmark(model, adbackend::Symbol, islinked::Bool)

Benchmark evaluation and gradient calculation for `model` using the selected varinfo type
and AD backend.

Available varinfo choices:
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
Benchmark evaluation and gradient calculation for `model` using the selected AD backend.

The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).

`islinked` determines whether to link the VarInfo for evaluation.
"""
function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
function benchmark(model, adbackend::Symbol, islinked::Bool)
rng = StableRNG(23)

vi = VarInfo(rng, model)
adbackend = to_backend(adbackend)

vi = if varinfo_choice == :untyped
DynamicPPL.untyped_varinfo(rng, model)
elseif varinfo_choice == :typed
DynamicPPL.typed_varinfo(rng, model)
elseif varinfo_choice == :simple_namedtuple
SimpleVarInfo{Float64}(model(rng))
elseif varinfo_choice == :simple_dict
retvals = model(rng)
vns = [VarName{k}() for k in keys(retvals)]
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
elseif varinfo_choice == :typed_vector
DynamicPPL.typed_vector_varinfo(rng, model)
elseif varinfo_choice == :untyped_vector
DynamicPPL.untyped_vector_varinfo(rng, model)
else
error("Unknown varinfo choice: $varinfo_choice")
end

adbackend = to_backend(adbackend)

if islinked
vi = DynamicPPL.link(vi, model)
end

return run_ad(
model, adbackend; varinfo=vi, benchmark=true, test=NoTest(), verbose=false
)
end

end # module
end
2 changes: 1 addition & 1 deletion benchmarks/src/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Models for benchmarking Turing.jl.

Each model returns a NamedTuple of all the random variables in the model that are not
observed (this is used for constructing SimpleVarInfos).
observed.
"""
module Models

Expand Down
2 changes: 0 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Expand All @@ -24,7 +23,6 @@ DocumenterMermaid = "0.1, 0.2"
DynamicPPL = "0.40"
FillArrays = "0.13, 1"
ForwardDiff = "0.10, 1"
JET = "0.9, 0.10, 0.11"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
MarginalLogDensities = "0.4"
Expand Down
Loading