12 changes: 11 additions & 1 deletion .github/workflows/Enzyme.yml
@@ -18,7 +18,7 @@ concurrency:

jobs:
enzyme:
runs-on: ubuntu-latest
runs-on: macos-latest
steps:
- uses: actions/checkout@v5

@@ -27,9 +27,19 @@ jobs:
version: "1.11"

- uses: julia-actions/cache@v2
id: julia-cache

- name: Run AD with Enzyme on demo models
working-directory: test/integration/enzyme
run: |
julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()'
julia --project=. --color=yes main.jl

- name: Save Julia depot cache on cancel or failure
id: julia-cache-save
if: cancelled() || failure()
uses: actions/cache/save@v4
with:
path: |
${{ steps.julia-cache.outputs.cache-paths }}
key: ${{ steps.julia-cache.outputs.cache-key }}
37 changes: 25 additions & 12 deletions HISTORY.md
@@ -4,6 +4,20 @@

### Breaking changes

#### Fast Log Density Functions

This version reimplements `LogDensityFunction`, yielding performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
Exact speedups depend on model size: larger models see smaller relative gains, because the bulk of the work is done in calls to `logpdf`.

For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.

As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
If you were previously relying on this behaviour, you will need to store a VarInfo separately.

Along with this change, DynamicPPL now exposes the `fast_evaluate!!` method which allows you to hook into this 'fast evaluation' pipeline directly.
Please see the documentation for details.
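
As a hedged sketch of what hooking into this pipeline can look like (the `default_accumulators` helper and the parameter values below are illustrative assumptions; the call signature follows the `fast_evaluate!!` example later in this changelog):

```julia
using DynamicPPL, Distributions
using Random: Xoshiro

@model function demo()
    x ~ Normal()
end

# Assumed inputs: fixed parameter values for `x`, and a default set of
# accumulators (log prior + log likelihood); check the docs for the exact API.
params = (; x=0.5)
accs = DynamicPPL.default_accumulators()

retval, vi = DynamicPPL.fast_evaluate!!(
    Xoshiro(468), demo(), InitFromParams(params), accs
)
```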

#### Parent and leaf contexts

The `DynamicPPL.NodeTrait` function has been removed.
@@ -17,25 +31,24 @@ Leaf contexts require no changes, apart from a removal of the `NodeTrait` function.
`ConditionContext` and `PrefixContext` are no longer exported.
You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead.

#### SimpleVarInfo

`SimpleVarInfo` has been removed.
> **Review comment (Member):** `SimpleVarInfo` has cleaner code than `VarInfo`. Doesn't it make more sense to remove `VarInfo` and replace it with `SimpleVarInfo`?

> **Reply from @penelopeysm (Member Author), Nov 19, 2025:** Markus can probably say more, but I don't think the choice is between SVI and VI. The structs are actually the same, except that SVI carries an extra field to store a transformation (which would typically be part of VI's metadata field).
>
> ```julia
> struct SimpleVarInfo{NT,Accs<:AccumulatorTuple,C<:AbstractTransformation} <:
>        AbstractVarInfo
>     "underlying representation of the realization represented"
>     values::NT
>     "tuple of accumulators for things like log prior and log likelihood"
>     accs::Accs
>     "represents whether it assumes variables to be transformed"
>     transformation::C
> end
>
> struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
>     metadata::Tmeta
>     accs::Accs
> end
> ```
>
> I think the complexity lies in what kinds of metadata they carry. SVI code looks simpler because src/simple_varinfo.jl only handles NamedTuple or Dict metadata. VarInfo code looks terrible because src/varinfo.jl mostly deals with Metadata metadata (hence all the generated functions).
>
> So instead of saying that this removes SVI, it may be more accurate to say that it removes NamedTuple and Dict as supported forms of metadata. As far as I can tell it's mostly a coincidence that this code is associated with SVI.
>
> If the aim is to clean up src/varinfo.jl, then what really needs to happen is to replace Metadata with VNV (although VNV has 1.7k lines of code to itself, so IMO that's not exactly clean; but it probably feels better because it's in its own file, more documented, etc.)

(#1105)

> **Review comment (Member):** IIRC, the plan is to use `SimpleVarInfo` with VNV and VNT as the unifying tracing data type, largely because `SimpleVarInfo` has a cleaner API and better documentation. I agree that the differences between SVI and VI aren't all that significant.

Its main purpose was to enable rapid model evaluation.
However, `fast_evaluate!!` now provides a cleaner way of doing this.
In particular, if you want to evaluate a model at a given set of parameters, you can do:

```julia
retval, vi = DynamicPPL.fast_evaluate!!(rng, model, InitFromParams(params), accs)
```

#### Miscellaneous

Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.

The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
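
For illustration only, a strategy written against the new contract might look like the following sketch (the strategy name and the exact argument list of `init` are assumptions here, not taken from the docs):

```julia
using DynamicPPL, Distributions

# Hypothetical strategy: initialise each variable at the mode of its prior.
struct InitFromMode <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(rng, vn, dist, ::InitFromMode)
    # New contract: return (value, transform). The transform maps the value
    # back to unlinked space; returning `identity` means the value is already
    # unlinked, which recovers the old behaviour.
    return mode(dist), identity
end
```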

### Other changes

#### FastLDF

Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
Exact speedups depend on model size: larger models see smaller relative gains, because the bulk of the work is done in calls to `logpdf`.

Please note that `FastLDF` is currently considered internal and its API may change without warning.
We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.

For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
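
As a rough, non-authoritative sketch, `FastLDF` is expected to slot into the LogDensityProblems.jl interface; the constructor shown here is assumed to mirror `LogDensityFunction`'s keyword interface, which may not hold given that the API is explicitly unstable:

```julia
using DynamicPPL, Distributions, ADTypes
using LogDensityProblems

@model function demo()
    x ~ Normal()
end

# Assumed constructor; consult the PR linked above for the actual signature.
f = DynamicPPL.Experimental.FastLDF(demo(); adtype=AutoForwardDiff())

lp = LogDensityProblems.logdensity(f, [0.5])
lp_and_grad = LogDensityProblems.logdensity_and_gradient(f, [0.5])
```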

## 0.38.9

Remove warning when using Enzyme as the AD backend.
2 changes: 0 additions & 2 deletions benchmarks/benchmarks.jl
@@ -68,9 +68,7 @@ function run(; to_json=false)
false,
),
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
10 changes: 1 addition & 9 deletions benchmarks/src/DynamicPPLBenchmarks.jl
@@ -1,6 +1,6 @@
module DynamicPPLBenchmarks

using DynamicPPL: VarInfo, SimpleVarInfo, VarName
using DynamicPPL: VarInfo, VarName
using DynamicPPL: DynamicPPL
using DynamicPPL.TestUtils.AD: run_ad, NoTest
using ADTypes: ADTypes
@@ -60,8 +60,6 @@ and AD backend.
Available varinfo choices:
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)

The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).

@@ -76,12 +74,6 @@ function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
DynamicPPL.untyped_varinfo(rng, model)
elseif varinfo_choice == :typed
DynamicPPL.typed_varinfo(rng, model)
elseif varinfo_choice == :simple_namedtuple
SimpleVarInfo{Float64}(model(rng))
elseif varinfo_choice == :simple_dict
retvals = model(rng)
vns = [VarName{k}() for k in keys(retvals)]
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
elseif varinfo_choice == :typed_vector
DynamicPPL.typed_vector_varinfo(rng, model)
elseif varinfo_choice == :untyped_vector
2 changes: 1 addition & 1 deletion benchmarks/src/Models.jl
@@ -2,7 +2,7 @@
Models for benchmarking Turing.jl.

Each model returns a NamedTuple of all the random variables in the model that are not
observed (this is used for constructing SimpleVarInfos).
observed.
"""
module Models

13 changes: 7 additions & 6 deletions docs/src/api.md
@@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface
LogDensityFunction
```

Internally, this is accomplished using:

```@docs
OnlyAccsVarInfo
fast_evaluate!!
```

## Condition and decondition

A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -352,12 +359,6 @@ set_transformed!!
Base.empty!
```

#### `SimpleVarInfo`

```@docs
SimpleVarInfo
```

### Accumulators

The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during execution, in what are called accumulators.
@@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
Expand Up @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
# below.
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
struct LogDensityFunctionWrapper{
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
}
logdensity::L
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
# actual log-density evaluation.
varinfo::V
end
function (lw::LogDensityFunctionWrapper)(x, _)
return LogDensityProblems.logdensity(lw.logdensity, x)
@@ -101,7 +106,7 @@ function DynamicPPL.marginalize(
# Construct the marginal log-density model.
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
mld = MarginalLogDensities.MarginalLogDensity(
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
)
return mld
end
@@ -190,7 +195,7 @@ function DynamicPPL.VarInfo(
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
# Extract the original VarInfo. Its contents will in general be junk.
original_vi = mld.logdensity.logdensity.varinfo
original_vi = mld.logdensity.varinfo
# Extract the stored parameters, which includes the modes for any marginalized
# parameters
full_params = MarginalLogDensities.cached_params(mld)
10 changes: 5 additions & 5 deletions src/DynamicPPL.jl
@@ -46,7 +46,6 @@ import Base:
# VarInfo
export AbstractVarInfo,
VarInfo,
SimpleVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
@@ -92,8 +91,10 @@ export AbstractVarInfo,
getargnames,
extract_priors,
values_as_in_model,
# LogDensityFunction
# LogDensityFunction and fasteval
LogDensityFunction,
fast_evaluate!!,
OnlyAccsVarInfo,
# Leaf contexts
AbstractContext,
contextualize,
@@ -172,7 +173,7 @@ Abstract supertype for data structures that capture random variables when executing a
probabilistic model and accumulate log densities such as the log likelihood or the
log joint probability of the model.

See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
See also: [`VarInfo`](@ref).
"""
abstract type AbstractVarInfo <: AbstractModelTrace end

@@ -194,11 +195,10 @@ include("default_accumulators.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("onlyaccs.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("logdensityfunction.jl")
include("fasteval.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
95 changes: 4 additions & 91 deletions src/abstract_varinfo.jl
@@ -502,52 +502,6 @@ If no `Type` is provided, return values as stored in `varinfo`.

# Examples

`SimpleVarInfo` with `NamedTuple`:

```jldoctest
julia> data = (x = 1.0, m = [2.0]);

julia> values_as(SimpleVarInfo(data))
(x = 1.0, m = [2.0])

julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])

julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
x => 1.0
m => [2.0]

julia> values_as(SimpleVarInfo(data), Vector)
2-element Vector{Float64}:
1.0
2.0
```

`SimpleVarInfo` with `OrderedDict`:

```jldoctest
julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]);

julia> values_as(SimpleVarInfo(data))
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]

julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])

julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]

julia> values_as(SimpleVarInfo(data), Vector)
2-element Vector{Float64}:
1.0
2.0
```

`VarInfo` with `NamedTuple` of `Metadata`:

```jldoctest
@@ -828,8 +782,8 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link!!(default_transformation(model, vi), vi, vns, model)
end
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
# has a dedicated implementation
# Note that VarInfo has a dedicated implementation so this is only a generic
# fallback (previously used for SimpleVarInfo)
model = setleafcontext(model, DynamicTransformationContext{false}())
vi = last(evaluate!!(model, vi))
return set_transformed!!(vi, t)
@@ -890,8 +844,8 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink!!(default_transformation(model, vi), vi, vns, model)
end
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
# has a dedicated implementation
# Note that VarInfo has a dedicated implementation so this is only a generic
# fallback (previously used for SimpleVarInfo)
model = setleafcontext(model, DynamicTransformationContext{true}())
vi = last(evaluate!!(model, vi))
return set_transformed!!(vi, NoTransformation())
@@ -946,47 +900,6 @@ This will be called prior to `model` evaluation, allowing one to perform a single `invlink!!` rather than invlinking on a per-variable
basis as is done with [`DynamicTransformation`](@ref).

See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref).

# Examples
```julia-repl
julia> using DynamicPPL, Distributions, Bijectors

julia> @model demo() = x ~ Normal()
demo (generic function with 2 methods)

julia> # By subtyping `Transform`, we inherit the `(inv)link!!`.
struct MyBijector <: Bijectors.Transform end

julia> # Define some dummy `inverse` which will be used in the `link!!` call.
Bijectors.inverse(f::MyBijector) = identity

julia> # We need to define `with_logabsdet_jacobian` for `MyBijector`
# (`identity` already has `with_logabsdet_jacobian` defined)
function Bijectors.with_logabsdet_jacobian(::MyBijector, x)
# Just using a large number of the logabsdet-jacobian term
# for demonstration purposes.
return (x, 1000)
end

julia> # Change the `default_transformation` for our model to be a
# `StaticTransformation` using `MyBijector`.
function DynamicPPL.default_transformation(::Model{typeof(demo)})
return DynamicPPL.StaticTransformation(MyBijector())
end

julia> model = demo();

julia> vi = SimpleVarInfo(x=1.0)
SimpleVarInfo((x = 1.0,), 0.0)

julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity`
vi_linked = link!!(vi, model)
Transformed SimpleVarInfo((x = 1.0,), 0.0)

julia> # Now performs a single `invlink!!` before model evaluation.
logjoint(model, vi_linked)
-1001.4189385332047
```
"""
function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
return maybe_invlink_before_eval!!(transformation(vi), vi, model)
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/transformation.jl
@@ -7,7 +7,7 @@ constrained space if `isinverse` or unconstrained if `!isinverse`.
Note that some `AbstractVarInfo` types, most notably `VarInfo`, override the
`DynamicTransformationContext` methods with more efficient implementations.
`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
how to do the transformation, used by e.g. `SimpleVarInfo`.
how to do the transformation.
"""
struct DynamicTransformationContext{isinverse} <: AbstractContext end

2 changes: 0 additions & 2 deletions src/experimental.jl
@@ -2,8 +2,6 @@ module Experimental

using DynamicPPL: DynamicPPL

include("fasteval.jl")

# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
"""
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)