Update and Refactor find_MAP and fit_laplace
#531
Conversation
Force-pushed from c18ee5a to 2fc4b45
This is no longer true; find_MAP returns dummy chain/draw dims now. I thought it was too much to break the arviz promise that the posterior always has chain/draw. Another jank choice is the temp_chain, temp_draw thing in …
Force-pushed from 4b9ba99 to 067860f
Force-pushed from d79d642 to 1af7049
Force-pushed from 6554ad8 to 2ea85fe
Force-pushed from 2ea85fe to 48b74f7
Pull Request Overview
This PR refactors the MAP-finding and Laplace approximation routines to improve performance by caching Hessian computations, reorganizes code into a laplace_approx submodule, and standardizes return types to ArviZ InferenceData.
- Cache and reuse Hessian sub-computations in `find_MAP`/`fit_laplace` workflows.
- Move all Laplace-related modules under `pymc_extras/inference/laplace_approx`.
- Update `find_MAP` to return `InferenceData` and simplify the `fit_laplace` interface.
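As a rough usage sketch of the refactored interface summarized above (the toy model is made up, and the exact `find_MAP` arguments and defaults are assumptions not shown in this PR):

```python
import pymc as pm

# New module location per this PR; find_MAP may also be re-exported at the package level.
from pymc_extras.inference.laplace_approx.find_map import find_MAP

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")
    mu = pm.Normal("mu", sigma=1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=[1.0, 0.5, 1.5])

    # find_MAP now returns an arviz.InferenceData instead of a dict of arrays.
    # The method argument (a SciPy optimizer name) is an assumption here.
    idata = find_MAP(method="L-BFGS-B")

# The MAP point lives in the posterior group, with dummy chain/draw dims of size 1.
print(idata.posterior)
```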
Reviewed Changes
Copilot reviewed 15 out of 18 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/inference/laplace_approx/test_find_map.py | Add tests for new find_MAP defaults, PSD helper, and JAX paths |
| pymc_extras/inference/laplace_approx/scipy_interface.py | New module for compiling loss/grad/Hessian for SciPy optimizers |
| pymc_extras/inference/laplace_approx/laplace.py | Refactor fit_laplace, cache inverse Hessian, build idata |
| pymc_extras/inference/laplace_approx/find_map.py | Refactor find_MAP, wrap into InferenceData, split logic |
| pymc_extras/inference/laplace_approx/idata.py | Helpers to add data/fit/optimizer results into InferenceData |
| pymc_extras/inference/pathfinder/pathfinder.py | Update import to new add_data_to_inference_data helper |
| pymc_extras/inference/fit.py | Route fit(method="laplace") to new Laplace submodule |
| pyproject.toml | Bump better-optimize dependency to ≥0.1.4 |
Comments suppressed due to low confidence (2)
pymc_extras/inference/laplace_approx/scipy_interface.py:101
- The docstring lists `f_fused` and `f_hessp` as return values, but the function actually returns a list of one or two `Function` objects. Update the docstring to reflect that it returns a `list[Function]` (or `[Function, Function]`).
`f_fused: Function`
pymc_extras/inference/laplace_approx/scipy_interface.py:53
- The return statement uses a starred expression (`return *loss_and_grad, hess`), which is invalid syntax in Python. Wrap the unpacking in a tuple, e.g. `return (*loss_and_grad, hess)`.
`return *loss_and_grad, hess`
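For what it's worth, a tiny self-contained illustration of the two return styles at issue (names mirror the quoted snippet; the real compiled Function objects are stubbed out with strings, and note that the bare starred form does parse on Python 3.8+):

```python
def stubbed_compile_functions():
    # Stand-ins for the compiled loss/grad and Hessian Functions.
    loss_and_grad = ("f_loss_and_grad",)
    hess = "f_hess"
    # The form Copilot suggests, with the unpacking wrapped in a tuple:
    return (*loss_and_grad, hess)
    # The bare `return *loss_and_grad, hess` is equivalent on Python >= 3.8.


print(stubbed_compile_functions())  # ('f_loss_and_grad', 'f_hess')
```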
This should be ready to go. Last changes:
Force-pushed from 48b74f7 to 3f2aa8b
```python
f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
posterior_draws = f_constrain(posterior_draws)
# There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
```
arviz suggests a separate group for these cases https://python.arviz.org/en/latest/schema/schema.html#unconstrained-posterior
But we don't have coords for those ofc
That would be for all the unconstrained values yeah, not just the oddball ones?
No idea. I don't see why we have this function though (see my other comment)
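For reference (re: the ArviZ unconstrained_posterior suggestion above), a minimal sketch of what storing the unconstrained draws in a separate group could look like. The group name follows the linked schema; the array shapes and variable names are invented, and older ArviZ versions may warn about a non-standard group:

```python
import numpy as np
import xarray as xr
import arviz as az

rng = np.random.default_rng(0)

# Toy draws of the flat unconstrained vector: (chain, draw, unconstrained_dim)
unconstrained_draws = rng.normal(size=(1, 500, 3))

idata = az.InferenceData()
idata.add_groups(
    unconstrained_posterior=xr.Dataset(
        {"params": (("chain", "draw", "unconstrained_dim"), unconstrained_draws)}
    )
)
print(idata.unconstrained_posterior)
```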
```python
    random_seed: int | np.random.Generator | None = None,
    compile_kwargs: dict | None = None,
) -> az.InferenceData:


def unstack_laplace_draws(idata, model):
```
Why do we have this function? Didn't you make a model where each variable is already an unstacked deterministic?
Let me double-check, might be a holdover from the old stuff that I got confused about.
The unobserved_value_vars graph already converts from latent space to constrained space; then you vectorize that with batch draws and you have everything. That's how I read it. Which, if true, is nice: this function can all go, and you don't worry about coords, since you don't store constrained draws per RV, only the whole concatenated vector?
Ok looking it back over, the purpose of this function is to take the draws from the actual laplace approximation and also return those as part of the posterior. My logic is that users might want this for diagnostic purposes, since this is where the multivariate normal actually lives.
My impression (based on very little) is that most packages won't do the constraining transformation on the outputs at all. I have this impression because people often cite "not respecting the domain of the priors" as a reason why Laplace isn't the best tool.
Isn't that just the long flattened vector?
Okay. Still seems like a niche place to invent this stuff here.
I also agree with that.
I was thinking that the right solution for this is to add a method to transformers that acts on coords. Conceptually, that seems like the right place for this.
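A purely hypothetical sketch of that idea (no such method exists in PyMC today; the class and method names are invented for illustration):

```python
# Hypothetical illustration only, not an existing PyMC API.
# The idea from the comment above: each transform knows how the dims/coords of
# the constrained RV map onto those of its unconstrained value variable.
class SimplexTransformSketch:
    def transform_dims(self, dims: tuple[str, ...]) -> tuple[str, ...]:
        # A length-n simplex has n - 1 unconstrained values, so the support
        # dim gets a new (shorter) coordinate on the unconstrained side.
        *batch_dims, support_dim = dims
        return (*batch_dims, f"{support_dim}__unconstrained")
```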
On the strategy point: you don't need two steps. When you did the flat->constrained mapping using unobserved_value_names as outputs, you could have done a flat->constrained+unconstrained mapping, using value_vars + unobserved_value_vars as outputs.
Where is this? In the call to join_nonshared_inputs?
yeah
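For context, a rough sketch of the one-shot flat -> (unconstrained + constrained) mapping being suggested (the toy model is invented; this follows the description in the thread rather than code from the PR):

```python
import pymc as pm
from pymc.pytensorf import join_nonshared_inputs

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")
    mu = pm.Normal("mu", sigma=sigma)

# Request both the unconstrained value variables and the constrained RVs as
# outputs of a single function of one flat input vector.
outputs = model.value_vars + model.unobserved_value_vars
new_outputs, flat_input = join_nonshared_inputs(
    point=model.initial_point(), outputs=outputs, inputs=model.value_vars
)
f = pm.compile(inputs=[flat_input], outputs=new_outputs)
```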
pm.sample_posterior_predictive filters out value variable names here, so doing everything symbolically in one shot doesn't work.
I'm sticking with what I have for now; I really just need this PR to be done.
* Move laplace and find_map to submodule
* Split idata utilities into `idata.py`
* Refactor find_MAP
* Refactor fit_laplace
* Update better-optimize version pin
* Handle labeling of non-scalar RVs without dims
* Add unconstrained posterior draws/points to unconstrained_posterior
I just made an update to better-optimize that uses Hessian matrix caching for better performance. This is something we can immediately take advantage of with `find_MAP`, re-using sub-computations from the loss or gradient in the Hessian computation. This PR updates the functions generated by `find_MAP` to take advantage of this (a rough sketch of the caching idea follows the list below).

While I was at it, I went ahead and did some cleanup and reorganization of the code. In particular:

- Everything Laplace-related now lives in a `laplace_approx` submodule.
- `find_MAP` now returns an idata. This is more consistent with all the other PyMC sampling functions -- it's weird to get back a dictionary in this one case.
- `find_MAP` will now always store the inverse Hessian. This is done to try to avoid an extra function compilation when it is used in conjunction with `fit_laplace`.
- `fit_laplace` was a really dumb function that was inexplicably sampling from scipy distributions. This required a ton of unnecessary work. If only we had a PPL that could help sample from complicated distributions...
- `fit_laplace` still isn't perfect. I wanted to store both the value variables and the transformed RVs as deterministics in a pymc model and sample them directly, but that doesn't appear to work -- maybe this is a bug? I ended up doing two passes, once for the constrained RVs, then a second pass for the unconstrained. It would be good to minimize that.
- I also removed as many little options that were floating around as possible. These function signatures were already horrible.
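As promised above, a rough, purely illustrative sketch of the kind of sub-computation reuse this enables (this is not the code in scipy_interface.py; the idea is that a fused callable computes loss, gradient, and Hessian together so shared work is done once per point):

```python
import numpy as np


def make_fused(loss_fn, grad_fn, hess_fn):
    """Toy fused loss/grad/Hessian: evaluate all three per point, cached by input."""
    cache: dict[bytes, tuple] = {}

    def f_fused(x: np.ndarray) -> tuple:
        key = x.tobytes()
        if key not in cache:
            cache.clear()  # keep only the most recent point, as in an optimizer loop
            cache[key] = (loss_fn(x), grad_fn(x), hess_fn(x))
        return cache[key]

    return f_fused


# Toy quadratic: loss = 0.5 * x @ A @ x, grad = A @ x, hess = A
A = np.diag([2.0, 3.0])
f_fused = make_fused(lambda x: 0.5 * x @ A @ x, lambda x: A @ x, lambda x: A)
loss, grad, hess = f_fused(np.array([1.0, -1.0]))
print(loss, grad, hess, sep="\n")
```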
Finally, I eliminated a lot of test parameterizations to speed the CI up, but also added a lot of new tests for functions that were previously not covered. Hopefully it's still net positive.
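And a rough end-to-end sketch of the resulting `fit_laplace` workflow (the toy model and the argument-free call are assumptions; the group names follow the commit messages above):

```python
import pymc as pm

# New module layout per this PR; fit_laplace may also be re-exported elsewhere.
from pymc_extras.inference.laplace_approx.laplace import fit_laplace

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")
    mu = pm.Normal("mu", sigma=1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=[0.3, 1.2, 0.8])

    # Internally finds the MAP, reuses the stored inverse Hessian, and samples
    # from the resulting multivariate normal approximation.
    idata = fit_laplace()

print(idata.posterior)                 # constrained draws per RV
print(idata.unconstrained_posterior)   # draws where the MVN actually lives
```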