diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d3b70c847..e6901b260 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -62,7 +62,7 @@ class MvNormalSVD(MvNormal): A JAX MvNormal robust to low-rank covariance matrices """ - rv_op = MvNormalSVDRV() + rv_op = MvNormalSVDRV(method="svd") try: diff --git a/requirements-dev.txt b/requirements-dev.txt index a28518d8e..58d5ac401 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,5 @@ blackjax # Used as benchmark for statespace models statsmodels +scikit-learn +better-optimize