Skip to content

Commit 65f9d43

Browse files
Use new method argument for MvNormal to defined MvNormalSVD
1 parent 00a4ca3 commit 65f9d43

File tree

1 file changed

+2
-23
lines changed

1 file changed

+2
-23
lines changed

pymc_extras/statespace/filters/distributions.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,29 +62,8 @@ class MvNormalSVD(MvNormal):
6262
A JAX MvNormal robust to low-rank covariance matrices
6363
"""
6464

65-
rv_op = MvNormalSVDRV()
66-
67-
68-
try:
69-
import jax.random
70-
71-
from pytensor.link.jax.dispatch.random import jax_sample_fn
72-
73-
@jax_sample_fn.register(MvNormalSVDRV)
74-
def jax_sample_fn_mvnormal_svd(op, node):
75-
def sample_fn(rng, size, dtype, *parameters):
76-
rng_key = rng["jax_state"]
77-
rng_key, sampling_key = jax.random.split(rng_key, 2)
78-
sample = jax.random.multivariate_normal(
79-
sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
80-
)
81-
rng["jax_state"] = rng_key
82-
return (rng, sample)
83-
84-
return sample_fn
85-
86-
except ImportError:
87-
pass
65+
# TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
66+
rv_op = MvNormalSVDRV(method="svd")
8867

8968

9069
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):

0 commit comments

Comments
 (0)