Use typed varinfo in Prior #2649

Draft · wants to merge 60 commits into base: main

Changes from all commits (60 commits)
c700ddc
[no ci] Bump to v0.40.0
penelopeysm Jun 6, 2025
a6f2c9d
Merge branch 'main' into breaking
penelopeysm Jul 19, 2025
0164e84
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Jul 19, 2025
cea1f7d
First efforts towards DPPL 0.37 compat, WIP
mhauru May 15, 2025
5d860d9
More DPPL 0.37 compat work, WIP
mhauru May 20, 2025
c7c4638
Add [sources] for [email protected]
penelopeysm Jul 17, 2025
f16a5cf
Remove context argument from `LogDensityFunction`
penelopeysm Jul 19, 2025
98d5e7a
Fix MH
penelopeysm Jul 19, 2025
73e127b
Remove spurious logging
penelopeysm Jul 19, 2025
ce0c782
Remove residual OptimizationContext
penelopeysm Jul 19, 2025
4d03c07
Delete files that were removed in previous releases
penelopeysm Jul 19, 2025
06fec2d
Fix typo
penelopeysm Jul 19, 2025
0af8725
Simplify ESS
penelopeysm Jul 19, 2025
3d44c12
Fix LDF
penelopeysm Jul 19, 2025
a1837b5
Fix Prior(), fix a couple more imports
penelopeysm Jul 19, 2025
17efb8c
fixes
penelopeysm Jul 19, 2025
d62ad82
actually fix prior
penelopeysm Jul 19, 2025
aac93f1
Remove extra return value from tilde_assume
penelopeysm Jul 19, 2025
e903d1c
fix ldf
penelopeysm Jul 19, 2025
fd5a815
actually fix prior
penelopeysm Jul 19, 2025
10a130a
fix HMC log-density
penelopeysm Jul 20, 2025
c630723
fix ldf
penelopeysm Jul 20, 2025
9cbb2e9
fix make_evaluate_...
penelopeysm Jul 20, 2025
335cd2a
more fixes for evaluate!!
penelopeysm Jul 20, 2025
c912fb9
fix hmc
penelopeysm Jul 20, 2025
195f819
fix run_ad
penelopeysm Jul 20, 2025
cd52e9f
even more fixes (oh goodness when will this end)
penelopeysm Jul 20, 2025
9360f18
more fixes
penelopeysm Jul 20, 2025
64ebd92
fix
penelopeysm Jul 20, 2025
283d4dd
more fix fix fix
penelopeysm Jul 20, 2025
b346198
fix return values of tilde pipeline
penelopeysm Jul 20, 2025
9012774
even more fixes
penelopeysm Jul 20, 2025
e600589
Fix missing import
penelopeysm Jul 20, 2025
3d5072f
More MH fixes
penelopeysm Jul 20, 2025
37466cc
Fix conversion
penelopeysm Jul 20, 2025
1b73e5a
don't think it really needs those type params
penelopeysm Jul 20, 2025
66a8544
implement copy for LogPriorWithoutJacAcc
penelopeysm Jul 20, 2025
98e70c2
Even more fixes
penelopeysm Jul 20, 2025
d2c1c92
More fixes; I think the remaining failures are pMCMC related
penelopeysm Jul 20, 2025
465642e
Merge branch 'main' into breaking
penelopeysm Jul 21, 2025
a21f24d
Merge branch 'breaking' into mhauru/dppl-0.37
penelopeysm Jul 21, 2025
11a2a31
Fix merge
penelopeysm Jul 21, 2025
966e17b
Merge branch 'main' into breaking
penelopeysm Jul 28, 2025
7ca59ce
Merge branch 'breaking' into mhauru/dppl-0.37
penelopeysm Jul 28, 2025
c062867
DPPL 0.37 compat for particle MCMC (#2625)
mhauru Jul 31, 2025
7124864
"Fixes" for PG-in-Gibbs (#2629)
penelopeysm Jul 31, 2025
8fdecc0
Use accumulators to fix all logp calculations when sampling (#2630)
penelopeysm Aug 1, 2025
9d48201
Merge branch 'main' into breaking
penelopeysm Aug 1, 2025
eb2b7a7
Uncomment tests that should be there
penelopeysm Aug 1, 2025
27aab23
Merge branch 'breaking' into mhauru/dppl-0.37
penelopeysm Aug 1, 2025
119c818
InitContext isn't for 0.37, update comments
penelopeysm Aug 1, 2025
b41a4b1
Fix merge
penelopeysm Aug 1, 2025
d92fd56
Do not re-evaluate model for Prior (#2644)
penelopeysm Aug 5, 2025
806c82d
No need to test AD for SamplingContext{<:HMC} (#2645)
penelopeysm Aug 5, 2025
422fc68
Use typed varinfo for Prior
penelopeysm Aug 7, 2025
a4a4304
fix VAIMAcc bug
penelopeysm Aug 7, 2025
5743ff7
change breaking -> main
penelopeysm Aug 7, 2025
02515b5
Merge branch 'mhauru/dppl-0.37' into py/fast-prior
penelopeysm Aug 10, 2025
eff3bd9
No need to replace the accumulator following https://github.com/Turin…
penelopeysm Aug 10, 2025
42a3ceb
Add an explanatory comment
penelopeysm Aug 10, 2025
9 changes: 9 additions & 0 deletions HISTORY.md
@@ -1,3 +1,12 @@
# 0.40.0

TODO

- DynamicPPL 0.37 stuff

- pMCMC and Gibbs?
- Prior is faster

# 0.39.9

Revert a bug introduced in 0.39.5 in the external sampler interface.
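
For context, the user-facing entry point behind the "Prior is faster" item (and this PR's title, "Use typed varinfo in Prior") is ordinary prior sampling. A minimal sketch follows; the model is illustrative and not part of the PR.

    using Turing

    @model function gdemo(x)
        s ~ InverseGamma(2, 3)
        m ~ Normal(0, sqrt(s))
        x ~ Normal(m, sqrt(s))
    end

    # With this PR, Prior() keeps its draws in a typed VarInfo, avoiding the
    # type-unstable lookups of the untyped container (an interpretation of the
    # PR title, not a quoted benchmark).
    chain = sample(gdemo(1.5), Prior(), 1_000)
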
7 changes: 5 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.39.9"
version = "0.40.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.36.3"
DynamicPPL = "0.37"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"}
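
The new [sources] entry presumably pins DynamicPPL to its main branch while the 0.37 release is pending. A roughly equivalent manual step in a development environment (an assumption about intent, not part of the PR) would be:

    using Pkg
    Pkg.add(url="https://github.com/TuringLang/DynamicPPL.jl", rev="main")
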
21 changes: 6 additions & 15 deletions ext/TuringDynamicHMCExt.jl
@@ -58,15 +58,12 @@ function DynamicPPL.initialstep(
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi)
vi = DynamicPPL.link!!(vi, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
vi = last(DynamicPPL.evaluate!!(model, vi))
end

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
)

# Perform initial step.
@@ -76,12 +73,9 @@
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Update the variables.
vi = DynamicPPL.unflatten(vi, Q.q)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Turing.Inference.Transition(model, vi)
vi = DynamicPPL.unflatten(vi, Q.q)
sample = Turing.Inference.Transition(model, vi, nothing)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
@@ -100,12 +94,9 @@ function AbstractMCMC.step(
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Update the variables.
vi = DynamicPPL.unflatten(vi, Q.q)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Turing.Inference.Transition(model, vi)
vi = DynamicPPL.unflatten(vi, Q.q)
sample = Turing.Inference.Transition(model, vi, nothing)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
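
The recurring change in this file is the DynamicPPL 0.37 LogDensityFunction constructor: the log-density accessor (here DynamicPPL.getlogjoint_internal) is passed explicitly instead of a varinfo plus sampling context. A minimal sketch of that construction, assuming a DynamicPPL 0.37-compatible environment; the demo model and the AutoForwardDiff choice are illustrative, not part of the PR.

    using DynamicPPL, Distributions, ADTypes, LogDensityProblems

    @model function demo(x)
        m ~ Normal(0, 1)
        x ~ Normal(m, 1)
    end

    model = demo(1.0)
    vi = DynamicPPL.VarInfo(model)     # typed VarInfo
    vi = DynamicPPL.link!!(vi, model)  # work in unconstrained space

    # 0.36 style: LogDensityFunction(model, vi, SamplingContext(...); adtype=...)
    # 0.37 style: pass the accessor function instead of a context.
    ℓ = DynamicPPL.LogDensityFunction(
        model, DynamicPPL.getlogjoint_internal, vi; adtype=AutoForwardDiff()
    )

    θ = DynamicPPL.getparams(ℓ)        # flat parameter vector
    LogDensityProblems.logdensity(ℓ, θ)
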
27 changes: 13 additions & 14 deletions ext/TuringOptimExt.jl
@@ -34,8 +34,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +56,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -74,8 +72,8 @@ function Optim.optimize(
end

function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
return _optimize(f, args...; kwargs...)
end

"""
@@ -104,8 +102,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +124,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -144,9 +140,10 @@ function Optim.optimize(
end

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
return _optimize(f, args...; kwargs...)
end

"""
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

@@ -166,7 +163,9 @@ function _optimize(
# whether initialisation is really necessary at all
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
vi = DynamicPPL.link(vi, f.ldf.model)
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
f = Optimisation.OptimLogDensity(
f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
)
init_vals = DynamicPPL.getparams(f.ldf)

# Optimize!
@@ -184,7 +183,7 @@ function _optimize(
vi = f.ldf.varinfo
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
logdensity_optimum = Optimisation.OptimLogDensity(
f.ldf.model, vi_optimum, f.ldf.context
f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
)
vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
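
At the user level these methods back the Optim-based mode estimation interface; OptimLogDensity now takes a log-density accessor (getloglikelihood for MLE, getlogjoint for MAP) rather than an OptimizationContext. A sketch of the calls whose internals change here; the demo model is illustrative, not part of the PR.

    using Turing, Optim

    @model function gauss(x)
        μ ~ Normal(0, 1)
        x ~ Normal(μ, 1)
    end

    model = gauss(2.0)

    # Internally builds OptimLogDensity(model, DynamicPPL.getloglikelihood)
    mle_estimate = Optim.optimize(model, MLE())

    # Internally builds OptimLogDensity(model, DynamicPPL.getlogjoint)
    map_estimate = Optim.optimize(model, MAP())
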
2 changes: 0 additions & 2 deletions src/Turing.jl
@@ -71,7 +71,6 @@ using DynamicPPL:
unfix,
prefix,
conditioned,
@submodel,
to_submodel,
LogDensityFunction,
@addlogprob!
@@ -81,7 +80,6 @@ using OrderedCollections: OrderedDict
# Turing essentials - modelling macros and inference algorithms
export
# DEPRECATED
@submodel,
generated_quantities,
# Modelling - AbstractPPL and DynamicPPL
@model,
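
The dropped lines reflect that the deprecated @submodel macro is no longer imported or re-exported by Turing; to_submodel, which stays in the import list above, is its replacement. A minimal usage sketch with illustrative models that are not part of the PR:

    using Turing

    @model function inner()
        x ~ Normal(0, 1)
        return x
    end

    @model function outer(y)
        # Variables of `inner` are prefixed with the left-hand-side name
        # (e.g. `a.x`), and `a` is bound to the submodel's return value.
        a ~ to_submodel(inner())
        y ~ Normal(a, 1)
    end

    chain = sample(outer(0.5), NUTS(), 100)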