Commits (33)
7cddac7  Fast Log Density Function (penelopeysm, Nov 5, 2025)
5ed4295  Make it work with AD (penelopeysm, Nov 6, 2025)
e199520  Optimise performance for identity VarNames (penelopeysm, Nov 6, 2025)
4cefaca  Mark `get_range_and_linked` as having zero derivative (penelopeysm, Nov 6, 2025)
6dfd106  Update comment (penelopeysm, Nov 6, 2025)
41ee7f3  make AD testing / benchmarking use FastLDF (penelopeysm, Nov 6, 2025)
22e32a6  Fix tests (penelopeysm, Nov 6, 2025)
79cc128  Optimise away `make_evaluate_args_and_kwargs` (penelopeysm, Nov 6, 2025)
f7c6a78  const func annotation (penelopeysm, Nov 6, 2025)
b1a7650  Disable benchmarks on non-typed-Metadata-VarInfo (penelopeysm, Nov 6, 2025)
e60873a  Fix `_evaluate!!` correctly to handle submodels (penelopeysm, Nov 6, 2025)
fa0664e  Actually fix submodel evaluate (penelopeysm, Nov 6, 2025)
09a1fbb  Document thoroughly and organise code (penelopeysm, Nov 6, 2025)
7306ba4  Support more VarInfos, make it thread-safe (?) (penelopeysm, Nov 6, 2025)
53bccc1  fix bug in parsing ranges from metadata/VNV (penelopeysm, Nov 6, 2025)
30b9247  Fix get_param_eltype for TSVI (penelopeysm, Nov 6, 2025)
316937a  Disable Enzyme benchmark (penelopeysm, Nov 6, 2025)
075cee8  Don't override _evaluate!!, that breaks ForwardDiff (sometimes) (penelopeysm, Nov 6, 2025)
5f5a92c  Move FastLDF to experimental for now (penelopeysm, Nov 6, 2025)
0716de5  Fix imports, add tests, etc (penelopeysm, Nov 6, 2025)
cd2461e  More test fixes (penelopeysm, Nov 6, 2025)
1b8b873  Fix imports / tests (penelopeysm, Nov 6, 2025)
ff5680d  Remove AbstractFastEvalContext (penelopeysm, Nov 6, 2025)
500d5ac  Changelog and patch bump (penelopeysm, Nov 6, 2025)
e560c30  Add correctness tests, fix imports (penelopeysm, Nov 6, 2025)
22e9dbe  Merge branch 'main' into py/fastldf (penelopeysm, Nov 6, 2025)
c3bdcd0  Merge branch 'main' into py/fastldf (penelopeysm, Nov 6, 2025)
86d8a73  Concretise parameter vector in tests (penelopeysm, Nov 8, 2025)
4ec0c72  Merge branch 'main' into py/fastldf (penelopeysm, Nov 8, 2025)
4b324b0  Add zero-allocation tests (penelopeysm, Nov 9, 2025)
c55171b  Add Chairmarks as test dep (penelopeysm, Nov 9, 2025)
77deae9  Disable allocations tests on multi-threaded (penelopeysm, Nov 9, 2025)
8715446  Fast InitContext (#1125) (penelopeysm, Nov 10, 2025)
ext/DynamicPPLEnzymeCoreExt.jl (13 changes: 6 additions & 7 deletions)
@@ -1,16 +1,15 @@
 module DynamicPPLEnzymeCoreExt
 
-if isdefined(Base, :get_extension)
-    using DynamicPPL: DynamicPPL
-    using EnzymeCore
-else
-    using ..DynamicPPL: DynamicPPL
-    using ..EnzymeCore
-end
+using DynamicPPL: DynamicPPL
+using EnzymeCore
 
 # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
 # only checks whether such a method exists, and never runs it.
 @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
     nothing
+# Likewise for get_range_and_linked.
+@inline EnzymeCore.EnzymeRules.inactive(
+    ::typeof(DynamicPPL.get_range_and_linked), args...
+) = nothing
 
 end
ext/DynamicPPLMooncakeExt.jl (3 changes: 2 additions & 1 deletion)
@@ -1,9 +1,10 @@
 module DynamicPPLMooncakeExt
 
-using DynamicPPL: DynamicPPL, is_transformed
+using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked
 using Mooncake: Mooncake
 
 # This is purely an optimisation.
 Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
+Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg}
 
 end # module
src/DynamicPPL.jl (1 change: 1 addition & 0 deletions)
@@ -191,6 +191,7 @@ include("simple_varinfo.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("logdensityfunction.jl")
+include("fastldf.jl")
 include("model_utils.jl")
 include("extract_priors.jl")
 include("values_as_in_model.jl")
src/fastldf.jl (new file, 176 additions & 0 deletions)
@@ -0,0 +1,176 @@
struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
    accs::Accs
end
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
[Review comment by @penelopeysm (Member, Author), Nov 6, 2025, on the line above:]

It turns out that a substantial amount of the performance gains were in this innocuous line 😄 Just optimising the implementation of this function for the normal VarInfo, plus sticking some @views in appropriate places, gets the old LogDensityFunction to within ~50% of FastLDF's performance. See #1115.

In some ways the simplification here is a result of the current approach: because the linked status is no longer stored in a VarInfo object but rather in a separate struct, there is no longer any need to care about linking / invlinking a VarInfo.

But overall I think this means that the performance gains from the current approach are much more comprehensible.
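
For intuition, here is a rough sketch of the lookup tables this design produces; the model and the particular ranges are hypothetical, not taken from this PR:

# Hypothetical example: for a model with `s ~ InverseGamma(2, 3)` linked to
# unconstrained space and `m ~ Normal(0, sqrt(s))` left untransformed, the
# FastLDF constructor further down would conceptually build
#
#   iden_varname_ranges = (s = RangeAndLinked(1:1, true), m = RangeAndLinked(2:2, false))
#   varname_ranges      = Dict{VarName,RangeAndLinked}()  # empty: no non-identity VarNames
#
# so tilde_assume!! can pick from_linked_vec_transform or from_vec_transform per
# variable without ever invlinking a VarInfo.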

DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)

struct RangeAndLinked
    # indices that the variable corresponds to in the vectorised parameter
    range::UnitRange{Int}
    # whether it's linked
    is_linked::Bool
end

struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext
    # The ranges of identity VarNames are stored in a NamedTuple for improved performance
    # (it's around 1.5x faster).
    iden_varname_ranges::N
    # This Dict stores the ranges for all other VarNames
    varname_ranges::Dict{VarName,RangeAndLinked}
    # The full parameter vector which we index into to get variable values
    params::T
end
DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf()
function get_range_and_linked(
    ctx::FastLDFContext, ::VarName{sym,typeof(identity)}
) where {sym}
    return ctx.iden_varname_ranges[sym]
end
function get_range_and_linked(ctx::FastLDFContext, vn::VarName)
    return ctx.varname_ranges[vn]
end

function tilde_assume!!(
    ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo
)
    # Don't need to read the data from the varinfo at all since it's
    # all inside the context.
    range_and_linked = get_range_and_linked(ctx, vn)
    y = @view ctx.params[range_and_linked.range]
    f = if range_and_linked.is_linked
        from_linked_vec_transform(right)
    else
        from_vec_transform(right)
    end
    x, inv_logjac = with_logabsdet_jacobian(f, y)
    vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
    return x, vi
end

function tilde_observe!!(
    ::FastLDFContext,
    right::Distribution,
    left,
    vn::Union{VarName,Nothing},
    vi::OnlyAccsVarInfo,
)
    # This is the same as for DefaultContext
    vi = accumulate_observe!!(vi, right, left, vn)
    return left, vi
end

struct FastLDF{
    M<:Model,
    F<:Function,
    N<:NamedTuple,
    AD<:Union{ADTypes.AbstractADType,Nothing},
    ADP<:Union{Nothing,DI.GradientPrep},
}
    _model::M
    _getlogdensity::F
    # See FastLDFContext for explanation of these two fields
    _iden_varname_ranges::N
    _varname_ranges::Dict{VarName,RangeAndLinked}
    _adtype::AD
    _adprep::ADP

    function FastLDF(
        model::Model,
        getlogdensity::Function,
        # This only works with typed Metadata-varinfo.
        # Obviously, this can be generalised later.
        varinfo::VarInfo{<:NamedTuple{syms}};
        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
    ) where {syms}
        # Figure out which variable corresponds to which index, and
        # which variables are linked.
        all_iden_ranges = NamedTuple()
        all_ranges = Dict{VarName,RangeAndLinked}()
        offset = 1
        for sym in syms
            md = varinfo.metadata[sym]
            for (vn, idx) in md.idcs
                len = length(md.ranges[idx])
                is_linked = md.is_transformed[idx]
                range = offset:(offset + len - 1)
                if AbstractPPL.getoptic(vn) === identity
                    all_iden_ranges = merge(
                        all_iden_ranges,
                        NamedTuple((
                            AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),
                        )),
                    )
                else
                    all_ranges[vn] = RangeAndLinked(range, is_linked)
                end
                offset += len
            end
        end
        # Do AD prep if needed
        prep = if adtype === nothing
            nothing
        else
            # Make backend-specific tweaks to the adtype
            adtype = tweak_adtype(adtype, model, varinfo)
            x = [val for val in varinfo[:]]
            DI.prepare_gradient(
                FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
                adtype,
                x,
            )
        end

        return new{
            typeof(model),
            typeof(getlogdensity),
            typeof(all_iden_ranges),
            typeof(adtype),
            typeof(prep),
        }(
            model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep
        )
    end
end

struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
    _model::M
    _getlogdensity::F
    _iden_varname_ranges::N
    _varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
    ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params)
    model = DynamicPPL.setleafcontext(f._model, ctx)
    # This can obviously also be optimised for the case where not
    # all accumulators are needed.
    accs = AccumulatorTuple((
        LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator()
    ))
    _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
    return f._getlogdensity(vi)
end

function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
    return FastLogDensityAt(
        fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
    )(
        params
    )
end

function LogDensityProblems.logdensity_and_gradient(
    fldf::FastLDF, params::AbstractVector{<:Real}
)
    return DI.value_and_gradient(
        FastLogDensityAt(
            fldf._model,
            fldf._getlogdensity,
            fldf._iden_varname_ranges,
            fldf._varname_ranges,
        ),
        fldf._adprep,
        fldf._adtype,
        params,
    )
end
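
For context, a minimal usage sketch of driving FastLDF through the LogDensityProblems interface. This is not part of the diff: the demo model, the choice of getlogjoint as the log-density getter, and the assumption that FastLDF is reachable without qualification are all illustrative.

using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff

@model function demo(y)
    m ~ Normal(0, 1)   # single identity VarName; ends up in iden_varname_ranges
    y ~ Normal(m, 1)   # observed model argument; handled by tilde_observe!!
end

model = demo(0.5)
vi = VarInfo(model)    # typed Metadata-based VarInfo, as the constructor requires

# `getlogjoint` is assumed here as the accumulator-reading function; the constructor
# leaves the choice of `getlogdensity` to the caller. Depending on how the PR is
# exported, `FastLDF` may need qualification (e.g. under DynamicPPL.Experimental,
# per the "Move FastLDF to experimental for now" commit).
fldf = FastLDF(model, DynamicPPL.getlogjoint, vi; adtype=AutoForwardDiff())

params = vi[:]         # flattened parameter vector
LogDensityProblems.logdensity(fldf, params)
LogDensityProblems.logdensity_and_gradient(fldf, params)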