Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers)
#7578
About 10 minutes of pytest CI time seem to be saved compared to previous runs.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7578      +/-   ##
==========================================
- Coverage   92.84%   92.84%   -0.01%
==========================================
  Files         106      106
  Lines       17686    17719      +33
==========================================
+ Hits        16421    16451      +30
- Misses       1265     1268       +3
lgtm, left a few small comments
Also initialize empty trace and set `trust_input=True`
Also removes default `model.check_start_vals()`
Major changes
- [...] `ravel_inputs` is specified explicitly. Eventually it will only be possible to use `ravel_inputs=True`.
- `assign_step_method` does not call `instantiate_steppers`, but returns the arguments needed for the latter.
- Allow specifying `compile_kwargs` to `pm.sample`, which is then forwarded to the step sampler functions.
Enhancement
This PR speeds up `NUTS` (and other step samplers), by:
- [...] `trust_input=True` which can have a large overhead.
This PR speeds up `sample` by:
- [...] `init_nuts`. This will also shorten the path towards external samplers with nutpie/numpyro, as it avoids the costly and useless compilation of the `logp_dlogp_function`.
- [...] `trust_input` and avoiding deepcopies in the trace function by using `pytensor.In(borrow=True)` and `pytensor.Out(borrow=True)`.
Further speedups should come for free from #7539, especially for the Numba backend.
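
To make the trusted-input and borrowing idea concrete, here is a minimal pure-Python sketch of the two per-call costs being removed: input validation and defensive copies of outputs. `ToyCompiledFunction` is a hypothetical stand-in written only for illustration. It is not PyTensor's implementation, and only the `trust_input` and `borrow` names are taken from the description above.

```python
import copy


class ToyCompiledFunction:
    """Hypothetical stand-in for a compiled function (not PyTensor's code)."""

    def __init__(self, trust_input=False, borrow=False):
        self.trust_input = trust_input
        self.borrow = borrow
        self._buffer = []  # internal storage holding the latest result

    def __call__(self, values):
        if not self.trust_input:
            # Safe path: validate and convert every input on every call.
            if not isinstance(values, (list, tuple)):
                raise TypeError("expected a list or tuple of numbers")
            values = [float(v) for v in values]
        # The actual computation: square each entry into the internal buffer.
        self._buffer = [v * v for v in values]
        # Borrowing hands back the internal buffer; otherwise copy defensively.
        return self._buffer if self.borrow else copy.deepcopy(self._buffer)


safe = ToyCompiledFunction()
fast = ToyCompiledFunction(trust_input=True, borrow=True)

point = [1.0, 2.0, 3.0]
safe_result = safe(point)
fast_result = fast(point)

# Both paths compute the same values; only the per-call overhead differs.
print(safe_result == fast_result)
```

The trade-off in this toy is the usual one: the fast path is only safe when the caller already guarantees well-formed inputs and does not keep a borrowed result across calls, which is plausible inside a sampler loop that controls both sides.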
Benchmark
In the example below, sampling time is now only 7x slower than nutpie (5s vs 0.7s), compared to 13.5x slower (9.45s vs 0.7s) before. This assumes the same number of logp evals; in fact, nutpie's tuning lets it get away with half the evals! We can hopefully bring that over.
The full time from calling `pm.sample` to getting a trace is roughly halved as well (7.5s vs 14.4s), although this gain is not proportional to the number of draws.
With `compile_kwargs=dict(mode="NUMBA")`, sampling time is only 3x slower (2.3s).
📚 Documentation preview 📚: https://pymc--7578.org.readthedocs.build/en/7578/
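
The slowdown factors quoted in the benchmark follow directly from the reported timings. As a quick sanity check, the sketch below recomputes them. The variable names are illustrative, and the timings are copied from the description rather than re-measured.

```python
# Timings in seconds, copied from the benchmark description.
nutpie = 0.7         # nutpie sampling time
pymc_before = 9.45   # PyMC sampling time before this PR
pymc_after = 5.0     # PyMC sampling time after this PR
pymc_numba = 2.3     # PyMC sampling time after this PR with the Numba mode
total_before = 14.4  # pm.sample call until a trace is returned, before
total_after = 7.5    # pm.sample call until a trace is returned, after

# Slowdown relative to nutpie is the ratio of sampling times.
slowdown_before = pymc_before / nutpie
slowdown_after = pymc_after / nutpie
slowdown_numba = pymc_numba / nutpie

# Fraction of the original end-to-end time that remains.
total_ratio = total_after / total_before

print(f"before this PR: {slowdown_before:.1f}x slower than nutpie")
print(f"after this PR:  {slowdown_after:.1f}x slower than nutpie")
print(f"with Numba:     {slowdown_numba:.1f}x slower than nutpie")
print(f"end-to-end time is {total_ratio:.0%} of what it was")
```

These ratios compare wall-clock time only; they say nothing about the number of logp evaluations each sampler needed.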