Unify Turing `Transition`s, fix some tests #2651
Conversation
Turing.jl documentation for PR #2651 is available at:
Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@             Coverage Diff              @@
##    mhauru/dppl-0.37    #2651     +/-   ##
============================================
- Coverage        85.87%   85.70%   -0.17%
============================================
  Files               22       22
  Lines             1444     1434      -10
============================================
- Hits              1240     1229      -11
- Misses             204      205       +1

☔ View full report in Codecov by Sentry.
@@ -114,6 +127,24 @@ end
    @test length(unique(c[:s])) == 1
end

@testset "addlogprob leads to reweighting" begin
this was sort of tested in PG-within-Gibbs, but we didn't have a PG-only test
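For context, a minimal sketch of what such a PG-only `@addlogprob!` test might look like. The model body and the increment value here are illustrative, not the exact test added in this PR:

```julia
using Turing

# Illustrative model: x ~ Normal() is symmetric about zero, but the
# @addlogprob! term adds extra log-probability whenever x > 0, so a
# sampler that accounts for @addlogprob! must produce a chain whose
# mass is shifted towards positive x.
@model function addlogprob_demo()
    x ~ Normal()
    if x > 0
        @addlogprob! 2.0
    end
end

chain = sample(addlogprob_demo(), PG(10), 100)
# A PG-only reweighting check would then assert something like
# mean(chain[:x]) being clearly positive.
```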
Gibbs(:b => PG(10), :x => ESS()),
Gibbs(:b => PG(20), :x => ESS()),
This test was failing on 1.10 due to numerical inaccuracy. I'm not sure why it was only failing on 1.10 and not 1.11, given that we were using StableRNGs; my first guess would be the RNG splitting in AdvancedPS.
I just bumped the atol up anyway because this test is so wonky (really we're mostly checking that it samples at all, since the results are incorrect depending on the interpretation of the model).
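As a side note on why atol bumps help here: floating-point accumulation can differ across Julia versions and platforms, so stochastic results are best compared with an explicit tolerance. A tiny base-Julia illustration, unrelated to the actual test values:

```julia
# Ten additions of 0.1 do not sum to exactly 1.0 in binary floating
# point, so an exact comparison fails while a tolerance-based one
# passes.
a = sum(fill(0.1, 10))         # 0.999..., not exactly 1.0
@assert a != 1.0
@assert isapprox(a, 1.0; atol=1e-8)
```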
Looks good, just one request for improving one of the tests.
)
# Avoid mutating vi as it may be used later e.g. when constructing
# sampler states.
vi = deepcopy(vi)
I wonder if this could be a significant time cost. In which case we could make sure we have a proper `copy` method for varinfos in DPPL. Would probably be good to have that anyway.
Does `copy` tend to be more performant than `deepcopy`?
Yeah, it can be. Depends on your data structures. `deepcopy` can be slow because it plays it safe with aliasing and such, whereas `copy` does whatever you make it do.
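A small base-Julia illustration of the aliasing difference (this is not the DynamicPPL varinfo code itself): `copy` is shallow, so nested mutable fields stay shared with the original, while `deepcopy` recursively duplicates everything, which is safer but can be much slower.

```julia
# copy is shallow: the Dict is duplicated but the array inside it is
# still the same object as in the original. deepcopy duplicates the
# array as well.
original = Dict(:metadata => [1, 2, 3])

shallow = copy(original)
deep = deepcopy(original)

push!(shallow[:metadata], 4)   # mutates the array shared with `original`
push!(deep[:metadata], 99)     # mutates an independent array

@assert original[:metadata] == [1, 2, 3, 4]   # saw the shallow mutation
@assert deep[:metadata] == [1, 2, 3, 99]      # fully independent
```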
test/mcmc/particle_mcmc.jl (Outdated)
end
c = sample(addlogprob_demo(), PG(10), 100)
# Result should be biased towards x > 0.
@test mean(c[:x]) > 0.5
Suggested change:
- @test mean(c[:x]) > 0.5
+ @test mean(c[:x]) > 0.7

Could this get a bit more clearance from 0.5, with, if necessary, the `@addlogprob!` value being increased? Otherwise there's a 50% chance this will pass just by luck.
Technically, x ~ `Normal()`, so the chance of it being larger than 0.5 is probably lower than 50%. That does assume that the chain has fully mixed, though, which might not be valid with 100 iterations, and there's no harm in making the addlogprob value larger, so I'll do that.
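For reference, a quick Monte Carlo sanity check of that claim in base Julia, assuming x ~ Normal(0, 1): the exact probability is 1 - Φ(0.5) ≈ 0.3085, comfortably below 50%.

```julia
using Random, Statistics

# Estimate P(x > 0.5) for x ~ Normal(0, 1) by simulation.
Random.seed!(468)
p = mean(randn(10^6) .> 0.5)

# Exact value is 1 - Φ(0.5) ≈ 0.3085; the Monte Carlo standard error
# with 10^6 draws is about 0.0005, so a 0.005 tolerance is generous.
@assert isapprox(p, 0.3085; atol=0.005)
```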
Also, I don't know why I didn't put in StableRNGs here. Will do so too.
Sorry, I wasn't thinking. I mixed up the cases of `x ~ Normal()` and `x ~ Bernoulli()`.
Remaining CI failures are due to Libtask on 1.12 (3x pre), the Julia segfault bug JuliaLang/julia#54840 (1x Inference / min), and SciML/RecursiveArrayTools.jl#477 (2x everything else / min).
Would like to remove the `@show`, but I'll just trust that you'll do it and approve already, because otherwise I'm happy.
Co-authored-by: Markus Hauru <[email protected]>
* First efforts towards DPPL 0.37 compat, WIP
* More DPPL 0.37 compat work, WIP
* Add [sources] for [email protected]
* Remove context argument from `LogDensityFunction`
* Fix MH
* Remove spurious logging
* Remove residual OptimizationContext
* Delete files that were removed in previous releases
* Fix typo
* Simplify ESS
* Fix LDF
* Fix Prior(), fix a couple more imports
* fixes
* actually fix prior
* Remove extra return value from tilde_assume
* fix ldf
* actually fix prior
* fix HMC log-density
* fix ldf
* fix make_evaluate_...
* more fixes for evaluate!!
* fix hmc
* fix run_ad
* even more fixes (oh goodness when will this end)
* more fixes
* fix
* more fix fix fix
* fix return values of tilde pipeline
* even more fixes
* Fix missing import
* More MH fixes
* Fix conversion
* don't think it really needs those type params
* implement copy for LogPriorWithoutJacAcc
* Even more fixes
* More fixes; I think the remaining failures are pMCMC related
* Fix merge
* DPPL 0.37 compat for particle MCMC (#2625)
* Progress in DPPL 0.37 compat for particle MCMC
* WIP PMCMC work
* Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628)
* Obviously this single commit will make Gibbs work
* Fixes for ESS
* Fix HMC call
* improve some comments
* Fixes to ProduceLogLikelihoodAccumulator
* Use LogProbAccumulator for ProduceLogLikelihoodAccumulator
* use get_conditioned_gibbs
---------
Co-authored-by: Penelope Yong <[email protected]>
* "Fixes" for PG-in-Gibbs (#2629)
* WIP PMCMC work
* Fixes to ProduceLogLikelihoodAccumulator
* inline definition of `set_retained_vns_del!`
* Fix ProduceLogLikelihoodAcc
* Remove all uses of `set_retained_vns_del!`
* Use nice functions
* Remove PG tests with dynamic number of Gibbs-conditioned-observations
* Fix essential/container tests
* Update pMCMC implementation as per discussion
* remove extra printing statements
* revert unneeded changes
* Add back (some kind of) dynamic model test
* fix rebase
* Add a todo comment for dynamic model tests
---------
Co-authored-by: Markus Hauru <[email protected]>
* Use accumulators to fix all logp calculations when sampling (#2630)
* Use new `getlogjoint` for optimisation
* Change getlogjoint -> getlogjoint_internal where needed
* Enforce re-evaluation when constructing `Transition`
* fix tests
* Remove extra evaluations from SGLD and SGHMC
* Remove dead `transitions_from_chain` method (used to be part of `predict`)
* metadata -> getstats_with_lp
* Clean up some stray getlogp
* InitContext isn't for 0.37, update comments
* Fix merge
* Do not re-evaluate model for Prior (#2644)
* Allow Prior to skip model re-evaluation
* remove unneeded `default_chain_type` method
* add a test
* add a likelihood term too
* why not test correctness while we're at it
* No need to test AD for SamplingContext{<:HMC} (#2645)
* change breaking -> main
* Remove calls to resetlogp!! & add changelog (#2650)
* Remove calls to resetlogp!!
* Add a changelog for 0.40
* Update HISTORY.md
Co-authored-by: Markus Hauru <[email protected]>
---------
Co-authored-by: Markus Hauru <[email protected]>
* Remove `[sources]`
* Unify Turing `Transition`s, fix some tests (#2651)
* Unify `Transition` methods
* Add tests
* Add same test for SGLD/SGHMC
* Refactor so that it's nice and organised
* Fix failing test on 1.10
* just increase the atol
* Make addlogprob test more robust
* Remove stray `@show`
Co-authored-by: Markus Hauru <[email protected]>
---------
Co-authored-by: Markus Hauru <[email protected]>
* Update changelog for PG in Gibbs
---------
Co-authored-by: Penelope Yong <[email protected]>
* [no ci] Bump to v0.40.0
* Uncomment tests that should be there
* Support DPPL 0.37 (#2550)
---------
Co-authored-by: Markus Hauru <[email protected]>
This PR removes `SGLDTransition`, `SMCTransition`, and `PGTransition` in favour of just using plain old `Transition`.

This is pretty much the last thing I want to add to 0.40. The reason I'd like to squeeze it in now is that it will correctly fix `chain[:lp]`, `chain[:logprior]`, etc. for all of the samplers we have, which feels like a nice reward after all the upfront work on DynamicPPL accumulators 😄

It also adds a deepcopy to the `Transition` constructor, as otherwise it would mutate the varinfo passed in, which messed things up for SGLD (I'm not sure why it didn't mess anything up for the other samplers). This deepcopy used to be there previously (Turing.jl/src/mcmc/Inference.jl, lines 192 to 194 in d75e6f2), so I think maybe removing it in #2630 was too ambitious.

Fixes most of #2631, although properly fixing it will have to wait for the removal of SampleFromPrior.