- 
                Notifications
    You must be signed in to change notification settings 
- Fork 72
More bugfixes for statespace #346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More bugfixes for statespace #346
Conversation
| Check out this pull request on   See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB | 
db3bb47    to
    dede2b3      
    Compare
  
    dede2b3    to
    491795e      
    Compare
  
    | This is ready for review. I will need to make changes again after the next pymc/pytensor releases, but for now this all works -- tests pass, and all notebooks run. For review I tried to organize the commits into chunks. It would be good to have eyeballs on the changes related to the distributions, since those were creating trouble for me in the first place. | 
| No idea what's going on with the CI, there seems to be a lot of broken stuff unrelated to this. | 
Explicitly set data shape to avoid broadcasting error Better handling of measurement error dims in `SARIMAX` models Freeze auxiliary models before forward sampling Bugfixes for posterior predictive sampling helpers Allow specification of time dimension name when registering data Save info about exogenous data for post-estimation tasks Restore `_exog_data_info` member variable Be more consistent with the names of filter outputs
Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests
Remove tests related to the `add_exogenous` method Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"` Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs Add signature and simple test for `SequenceMvNormal`
020e246    to
    ae3ecd1      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving so you can merge, let me know if you want me to look at some specific changes carefully
* Allow forward sampling of statespace models in JAX mode Explicitly set data shape to avoid broadcasting error Better handling of measurement error dims in `SARIMAX` models Freeze auxiliary models before forward sampling Bugfixes for posterior predictive sampling helpers Allow specification of time dimension name when registering data Save info about exogenous data for post-estimation tasks Restore `_exog_data_info` member variable Be more consistent with the names of filter outputs * Adjust test suite to reflect API changes Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests * Add JAX test suite * Bug-fixes and changes to statespace distributions Remove tests related to the `add_exogenous` method Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"` Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs Add signature and simple test for `SequenceMvNormal` * Re-run example notebooks * Add helper function to sample prior/posterior statespace matrices * fix tests * Wrap jax MvNormal rewrite in try/except block * Don't use `action` keyword in `catch_warnings` * Skip JAX test if `numpyro` is not installed * Handle batch dims on `SequenceMvNormal` * Remove unused batch_dim logic in SequenceMvNormal * Restore `get_support_shape_1d` import
* First draft of quadratic approximation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Review comments incorporated * License and copyright information added * Only add additional data to inferencedata when chains!=0 * Raise error if Hessian is singular * Replace for loop with call to remove_value_transforms * Pass model directly when finding MAP and the Hessian * Update pymc_experimental/inference/laplace.py Co-authored-by: Ricardo Vieira <[email protected]> * Remove chains from public parameters for Laplace approx method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Parameter draws is not optional with default value 1000 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add warning if numbers of variables in vars does not equal number of model variables * Update version.txt * `shock_size` should never be scalar * Blackjax API change * Handle latest PyMC/PyTensor breaking changes * Temporarily mark two tests as xfail * More bugfixes for statespace (#346) * Allow forward sampling of statespace models in JAX mode Explicitly set data shape to avoid broadcasting error Better handling of measurement error dims in `SARIMAX` models Freeze auxiliary models before forward sampling Bugfixes for posterior predictive sampling helpers Allow specification of time dimension name when registering data Save info about exogenous data for post-estimation tasks Restore `_exog_data_info` member variable Be more consistent with the names of filter outputs * Adjust test suite to reflect API changes Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests * Add JAX test suite * Bug-fixes and changes to statespace distributions Remove tests related to the `add_exogenous` method Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"` Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs Add signature and simple test for `SequenceMvNormal` * Re-run example notebooks * Add helper function to sample prior/posterior statespace matrices * fix tests * Wrap jax MvNormal rewrite in try/except block * Don't use `action` keyword in `catch_warnings` * Skip JAX test if `numpyro` is not installed * Handle batch dims on `SequenceMvNormal` * Remove unused batch_dim logic in SequenceMvNormal * Restore `get_support_shape_1d` import * Fix failing test case for laplace --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Jesse Grabowski <[email protected]>

The statespace model is still broken, this PR is another round of bugfixes.
statespace#326I need some help with fixing the JAX forward sampling. I'm doing something wrong, because even after freezing I have dynamic shape errors. This is the major blocker to considering the module "working" again.