From ea312c4dd116ecaef4562293aadce05c80ed1b6c Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Fri, 21 Feb 2025 20:19:58 +1100 Subject: [PATCH] Add to requirements-dev.txt and MvNormalSVD method --- pymc_extras/statespace/filters/distributions.py | 2 +- requirements-dev.txt | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d3b70c847..e6901b260 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -62,7 +62,7 @@ class MvNormalSVD(MvNormal): A JAX MvNormal robust to low-rank covariance matrices """ - rv_op = MvNormalSVDRV() + rv_op = MvNormalSVDRV(method="svd") try: diff --git a/requirements-dev.txt b/requirements-dev.txt index a28518d8e..58d5ac401 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,5 @@ blackjax # Used as benchmark for statespace models statsmodels +scikit-learn +better-optimize