Unify Turing Transitions, fix some tests #2651

Merged
merged 8 commits into mhauru/dppl-0.37 from py/single-transition
Aug 12, 2025

Member

@penelopeysm penelopeysm commented Aug 11, 2025

This PR removes SGLDTransition, SMCTransition, and PGTransition in favour of just using plain old Transition.

This is pretty much the last thing I want to add to 0.40. The reason why I'd like to squeeze it in now is because it will correctly fix all the chain[:lp], chain[:logprior], etc. for all of the samplers we have, which feels like a nice reward after all the upfront work on DynamicPPL accumulators 😄.
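To illustrate the idea of a single transition type, here is a minimal sketch. It is not Turing's actual `Transition` definition: `SketchTransition`, `sketch_stats`, and the `stepsize` field are hypothetical names, and the sketch assumes that `lp` is the sum of the log prior and the log likelihood.

```julia
# Hypothetical stand-in: NOT Turing's actual `Transition` definition.
struct SketchTransition{P<:NamedTuple,S<:NamedTuple}
    params::P               # parameter values at this iteration
    logprior::Float64       # log prior density
    loglikelihood::Float64  # log likelihood
    stats::S                # sampler-specific extras
end

# Statistics shared by every sampler, merged with the sampler-specific ones.
function sketch_stats(t::SketchTransition)
    shared = (
        lp = t.logprior + t.loglikelihood,
        logprior = t.logprior,
        loglikelihood = t.loglikelihood,
    )
    return merge(shared, t.stats)
end

# Two samplers, one transition type; only `stats` differs.
from_sgld = SketchTransition((x = 0.3,), -0.96, -1.2, (stepsize = 0.01,))
from_pg = SketchTransition((x = -0.1,), -0.92, -1.5, NamedTuple())

println(sketch_stats(from_sgld))
println(sketch_stats(from_pg))
```

Because every sampler goes through the same code path, the shared keys are always present, which is roughly why a single type makes it easier to keep `chain[:lp]` and friends consistent.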

It also adds a deepcopy to the Transition constructor, as otherwise it would mutate the varinfo passed in, which messed things up for SGLD (I'm not sure why it didn't mess anything up for the other samplers). This deepcopy used to be there previously:

# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
end

so I think maybe removing it in #2630 was too ambitious.
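A toy sketch of the failure mode, under stated assumptions: `ToyVarInfo`, `ToyTransition`, and `invlink!` are hypothetical stand-ins and not DynamicPPL types or functions. The point is only that a constructor which mutates its argument silently changes state the caller still holds, and that a `deepcopy` at the entry point prevents it.

```julia
# Hypothetical stand-ins for a varinfo and a transition.
mutable struct ToyVarInfo
    values::Dict{Symbol,Float64}
    linked::Bool
end

struct ToyTransition
    values::Dict{Symbol,Float64}
end

# Stand-in for a step that changes the argument in place.
function invlink!(vi::ToyVarInfo)
    vi.linked || return vi
    map!(exp, values(vi.values))  # overwrite every stored value
    vi.linked = false
    return vi
end

# Mutates the caller's `vi` as a side effect.
mutating_transition(vi) = ToyTransition(copy(invlink!(vi).values))

# Works on a private copy, so the caller's `vi` is untouched.
safe_transition(vi) = mutating_transition(deepcopy(vi))

vi = ToyVarInfo(Dict(:s => 0.0), true)

safe_transition(vi)
println((vi.linked, vi.values[:s]))  # caller's state is unchanged

mutating_transition(vi)
println((vi.linked, vi.values[:s]))  # caller's state has been modified
```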

Fixes most of #2631, although properly fixing it will have to wait for the removal of SampleFromPrior.

@penelopeysm penelopeysm changed the base branch from main to breaking August 11, 2025 13:59
@penelopeysm penelopeysm changed the base branch from breaking to mhauru/dppl-0.37 August 11, 2025 13:59
@penelopeysm penelopeysm marked this pull request as draft August 11, 2025 13:59
Contributor

Turing.jl documentation for PR #2651 is available at:
https://TuringLang.github.io/Turing.jl/previews/PR2651/


codecov bot commented Aug 11, 2025

Codecov Report

❌ Patch coverage is 94.73684% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 85.70%. Comparing base (bb21e1e) to head (7915287).
⚠️ Report is 1 commits behind head on mhauru/dppl-0.37.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/mcmc/particle_mcmc.jl | 91.66% | 1 Missing ⚠️ |

Additional details and impacted files
@@                 Coverage Diff                  @@
##           mhauru/dppl-0.37    #2651      +/-   ##
====================================================
- Coverage             85.87%   85.70%   -0.17%     
====================================================
  Files                    22       22              
  Lines                  1444     1434      -10     
====================================================
- Hits                   1240     1229      -11     
- Misses                  204      205       +1     

@penelopeysm penelopeysm changed the title from "Unify Turing Transitions" to "Unify Turing Transitions, fix some tests" Aug 11, 2025
@@ -114,6 +127,24 @@ end
@test length(unique(c[:s])) == 1
end

@testset "addlogprob leads to reweighting" begin
Member Author

this was sort of tested in PG-within-Gibbs, but we didn't have a PG-only test
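As a rough illustration of what "reweighting" means here, the sketch below uses plain self-normalised importance sampling rather than PG, and needs only the standard library. The penalty of `-10.0` for `x < 0` is an assumed value, not taken from the test, and `reweighted_mean` is a hypothetical helper.

```julia
using Random, Statistics

# Draw x ~ Normal(0, 1) and weight each draw by exp(extra log-probability),
# where the extra term is `penalty` for x < 0 and 0 otherwise.
function reweighted_mean(rng, n; penalty=-10.0)
    xs = randn(rng, n)
    logw = [x < 0 ? penalty : 0.0 for x in xs]
    w = exp.(logw)
    return sum(w .* xs) / sum(w)  # self-normalised weighted mean
end

rng = MersenneTwister(1)
println(reweighted_mean(rng, 100_000; penalty=0.0))  # no reweighting
println(reweighted_mean(rng, 100_000))               # strong penalty on x < 0
```

With a strong penalty the weighted mean approaches the mean of a standard normal truncated to `x > 0`, which is `sqrt(2 / π)`, so a sampler that respects the extra log-probability term should give a clearly positive `mean(c[:x])`.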

Gibbs(:b => PG(10), :x => ESS()),
Gibbs(:b => PG(20), :x => ESS()),
Member Author

@penelopeysm penelopeysm Aug 11, 2025

This test was failing on 1.10 due to numerical inaccuracy. Kind of unsure why it was only failing on 1.10 but not 1.11 given that we were using StableRNGs. My first guess would be the rng splitting in AdvancedPS.

I just bumped the atol up anyway because this test is so wonky (really we're mostly checking that it samples at all, since the results are incorrect depending on interpretation of model).
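For readers unfamiliar with tolerance-based checks, a minimal sketch using only the standard library. It uses `MersenneTwister` rather than StableRNGs, and the seed, sample size, and tolerance are arbitrary choices for illustration, not the values used in the test suite.

```julia
using Random, Statistics, Test

rng = MersenneTwister(468)          # seeded, so reruns give the same draws
samples = randn(rng, 1_000) .+ 2.0  # draws with true mean 2.0

# Passes as long as the sample mean lies within 0.2 of the true mean;
# a larger `atol` makes the check more forgiving of numerical noise.
@test mean(samples) ≈ 2.0 atol = 0.2
```

The trade-off is the usual one: a looser `atol` reduces spurious failures across platforms and Julia versions, at the cost of catching fewer genuine regressions.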

@penelopeysm penelopeysm marked this pull request as ready for review August 11, 2025 14:41
@penelopeysm penelopeysm requested a review from mhauru August 11, 2025 14:43
Member

@mhauru mhauru left a comment

Looks good, just one request for improving one of the tests.

)
# Avoid mutating vi as it may be used later e.g. when constructing
# sampler states.
vi = deepcopy(vi)
Member

I wonder if this could be a significant time cost. In which case we could make sure we have a proper copy method for varinfos in DPPL. Would probably be good to have that anyway.

Member Author

Does copy tend to be more performant than deepcopy?

Member

Yeah, it can be. Depends on your data structures. deepcopy can be slow because it plays it safe with aliasing and such, whereas copy does whatever you make it do.
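A small sketch of the difference, using a hypothetical `ToyState` type (not a DynamicPPL type): a hand-written `copy` can duplicate only the fields that get mutated and share the rest, whereas `deepcopy` duplicates everything reachable and tracks aliasing while doing so.

```julia
# Hypothetical state with one field that is mutated and one that is not.
struct ToyState
    params::Vector{Float64}  # updated during sampling
    data::Vector{Float64}    # large, never mutated
end

# Hand-written copy: duplicate what will be mutated, share the rest.
Base.copy(s::ToyState) = ToyState(copy(s.params), s.data)

s = ToyState([1.0, 2.0], rand(1_000))
c = copy(s)
d = deepcopy(s)

println(c.params !== s.params)  # true: params were duplicated
println(c.data === s.data)      # true: data is shared, not copied
println(d.data !== s.data)      # true: deepcopy duplicated data as well

# deepcopy preserves aliasing inside the copied object, which requires
# bookkeeping for every mutable object it visits.
a = [1.0]
aliased = deepcopy([a, a])
println(aliased[1] === aliased[2])  # true
```

Whether the custom `copy` is actually faster depends on how much of the structure can safely be shared, so it would need benchmarking on real varinfos.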

end
c = sample(addlogprob_demo(), PG(10), 100)
# Result should be biased towards x > 0.
@test mean(c[:x]) > 0.5
Member

Suggested change
- @test mean(c[:x]) > 0.5
+ @test mean(c[:x]) > 0.7

Could this get a bit more clearance from 0.5, with, if necessary, the @addlogprob! value being increased? Otherwise there's a 50% chance this will pass just by luck.

Member Author

Technically, x is ~ Normal() so the chance of it being larger than 0.5 is probably lower than 50%. That does assume that the chain has fully mixed though, which might not be valid with 100 iterations, and there's no harm to making the addlogprob value larger, so I'll do that.
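A quick Monte Carlo check of these numbers, as a sketch that assumes independent draws (a real chain is correlated, as noted above, so this is optimistic). The seed and trial counts are arbitrary.

```julia
using Random, Statistics

rng = MersenneTwister(2025)

# Fraction of single draws x ~ Normal(0, 1) that exceed 0.5.
single = mean(randn(rng, 100_000) .> 0.5)

# Number of trials in which the mean of 100 independent draws exceeds 0.5.
trials = 10_000
passes = count(_ -> mean(randn(rng, 100)) > 0.5, 1:trials)

println(single)
println(passes / trials)
```

For a single draw the fraction comes out at roughly 0.31, below 50%. For the mean of 100 independent draws the threshold 0.5 is five standard errors away from zero, so an unbiased and well-mixed sampler would essentially never pass by luck.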

Member Author

Also, I don't know why I didn't put in StableRNGs here. Will do so too.

Member

Sorry, I wasn't thinking. I mixed up the cases of x ~ Normal() and x ~ Bernoulli().

@penelopeysm
Member Author

Remaining CI failures are due to Libtask on 1.12 (3x pre), the Julia segfault bug JuliaLang/julia#54840 (1x Inference / min), and SciML/RecursiveArrayTools.jl#477 (2x everything else / min).

@penelopeysm penelopeysm requested a review from mhauru August 12, 2025 11:03
Member

@mhauru mhauru left a comment

Would like to remove the @show, but I'll just trust that you'll do it and approve already, because otherwise I'm happy.

Co-authored-by: Markus Hauru <[email protected]>
@penelopeysm penelopeysm merged commit 1bc2fbf into mhauru/dppl-0.37 Aug 12, 2025
25 of 30 checks passed
@penelopeysm penelopeysm deleted the py/single-transition branch August 12, 2025 11:27
penelopeysm added a commit that referenced this pull request Aug 12, 2025
* First efforts towards DPPL 0.37 compat, WIP

* More DPPL 0.37 compat work, WIP

* Add [sources] for [email protected]

* Remove context argument from `LogDensityFunction`

* Fix MH

* Remove spurious logging

* Remove residual OptimizationContext

* Delete files that were removed in previous releases

* Fix typo

* Simplify ESS

* Fix LDF

* Fix Prior(), fix a couple more imports

* fixes

* actually fix prior

* Remove extra return value from tilde_assume

* fix ldf

* actually fix prior

* fix HMC log-density

* fix ldf

* fix make_evaluate_...

* more fixes for evaluate!!

* fix hmc

* fix run_ad

* even more fixes (oh goodness when will this end)

* more fixes

* fix

* more fix fix fix

* fix return values of tilde pipeline

* even more fixes

* Fix missing import

* More MH fixes

* Fix conversion

* don't think it really needs those type params

* implement copy for LogPriorWithoutJacAcc

* Even more fixes

* More fixes; I think the remaining failures are pMCMC related

* Fix merge

* DPPL 0.37 compat for particle MCMC (#2625)

* Progress in DPPL 0.37 compat for particle MCMC

* WIP PMCMC work

* Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628)

* Obviously this single commit will make Gibbs work

* Fixes for ESS

* Fix HMC call

* improve some comments

* Fixes to ProduceLogLikelihoodAccumulator

* Use LogProbAccumulator for ProduceLogLikelihoodAccumulator

* use get_conditioned_gibbs

---------

Co-authored-by: Penelope Yong <[email protected]>

* "Fixes" for PG-in-Gibbs (#2629)

* WIP PMCMC work

* Fixes to ProduceLogLikelihoodAccumulator

* inline definition of `set_retained_vns_del!`

* Fix ProduceLogLikelihoodAcc

* Remove all uses of `set_retained_vns_del!`

* Use nice functions

* Remove PG tests with dynamic number of Gibbs-conditioned-observations

* Fix essential/container tests

* Update pMCMC implementation as per discussion

* remove extra printing statements

* revert unneeded changes

* Add back (some kind of) dynamic model test

* fix rebase

* Add a todo comment for dynamic model tests

---------

Co-authored-by: Markus Hauru <[email protected]>

* Use accumulators to fix all logp calculations when sampling (#2630)

* Use new `getlogjoint` for optimisation

* Change getlogjoint -> getlogjoint_internal where needed

* Enforce re-evaluation when constructing `Transition`

* fix tests

* Remove extra evaluations from SGLD and SGHMC

* Remove dead `transitions_from_chain` method (used to be part of `predict`)

* metadata -> getstats_with_lp

* Clean up some stray getlogp

* InitContext isn't for 0.37, update comments

* Fix merge

* Do not re-evaluate model for Prior (#2644)

* Allow Prior to skip model re-evaluation

* remove unneeded `default_chain_type` method

* add a test

* add a likelihood term too

* why not test correctness while we're at it

* No need to test AD for SamplingContext{<:HMC} (#2645)

* change breaking -> main

* Remove calls to resetlogp!! & add changelog (#2650)

* Remove calls to resetlogp!!

* Add a changelog for 0.40

* Update HISTORY.md

Co-authored-by: Markus Hauru <[email protected]>

---------

Co-authored-by: Markus Hauru <[email protected]>

* Remove `[sources]`

* Unify Turing `Transition`s, fix some tests (#2651)

* Unify `Transition` methods

* Add tests

* Add same test for SGLD/SGHMC

* Refactor so that it's nice and organised

* Fix failing test on 1.10

* just increase the atol

* Make addlogprob test more robust

* Remove stray `@show`

Co-authored-by: Markus Hauru <[email protected]>

---------

Co-authored-by: Markus Hauru <[email protected]>

* Update changelog for PG in Gibbs

---------

Co-authored-by: Penelope Yong <[email protected]>
penelopeysm added a commit that referenced this pull request Aug 12, 2025
* [no ci] Bump to v0.40.0

* Uncomment tests that should be there

* Support DPPL 0.37 (#2550)

---------

Co-authored-by: Markus Hauru <[email protected]>