Skip to content

Commit b96608f

Browse files
committed
Allow decomposition methods in MvNormal
1 parent bbf34cd commit b96608f

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
128128
@jax_sample_fn.register(ptr.BetaRV)
129129
@jax_sample_fn.register(ptr.DirichletRV)
130130
@jax_sample_fn.register(ptr.PoissonRV)
131-
@jax_sample_fn.register(ptr.MvNormalRV)
132131
def jax_sample_fn_generic(op, node):
133132
"""Generic JAX implementation of random variables."""
134133
name = op.name
@@ -173,6 +172,22 @@ def sample_fn(rng, size, dtype, *parameters):
173172
return sample_fn
174173

175174

175+
@jax_sample_fn.register(ptr.MvNormalRV)
176+
def jax_sample_mvnormal(op, node):
177+
"""Generic JAX implementation of random variables."""
178+
179+
def sample_fn(rng, size, dtype, mean, cov):
180+
rng_key = rng["jax_state"]
181+
rng_key, sampling_key = jax.random.split(rng_key, 2)
182+
sample = jax.random.multivariate_normal(
183+
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
184+
)
185+
rng["jax_state"] = rng_key
186+
return (rng, sample)
187+
188+
return sample_fn
189+
190+
176191
@jax_sample_fn.register(ptr.BernoulliRV)
177192
def jax_sample_fn_bernoulli(op, node):
178193
"""JAX implementation of `BernoulliRV`."""

pytensor/tensor/random/basic.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import abc
22
import warnings
3+
from typing import Literal
34

45
import numpy as np
56
import scipy.stats as stats
67
from numpy import broadcast_shapes as np_broadcast_shapes
78
from numpy import einsum as np_einsum
9+
from numpy import sqrt as np_sqrt
810
from numpy.linalg import cholesky as np_cholesky
11+
from numpy.linalg import eigh as np_eigh
12+
from numpy.linalg import svd as np_svd
913

10-
import pytensor
1114
from pytensor.tensor import get_vector_length, specify_shape
1215
from pytensor.tensor.basic import as_tensor_variable
1316
from pytensor.tensor.math import sqrt
@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
852855
signature = "(n),(n,n)->(n)"
853856
dtype = "floatX"
854857
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
858+
__props__ = ("name", "signature", "dtype", "inplace", "method")
855859

856-
def __call__(self, mean=None, cov=None, size=None, **kwargs):
860+
def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
861+
super().__init__(*args, **kwargs)
862+
if method not in ("cholesky", "svd", "eigh"):
863+
raise ValueError(
864+
f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
865+
)
866+
self.method = method
867+
868+
def __call__(self, mean, cov, size=None, **kwargs):
857869
r""" "Draw samples from a multivariate normal distribution.
858870
859871
Signature
@@ -876,33 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
876888
is specified, a single `N`-dimensional sample is returned.
877889
878890
"""
879-
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
880-
881-
if mean is None:
882-
mean = np.array([0.0], dtype=dtype)
883-
if cov is None:
884-
cov = np.array([[1.0]], dtype=dtype)
885891
return super().__call__(mean, cov, size=size, **kwargs)
886892

887-
@classmethod
888-
def rng_fn(cls, rng, mean, cov, size):
893+
def rng_fn(self, rng, mean, cov, size):
889894
if size is None:
890895
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
891896

892-
chol = np_cholesky(cov)
897+
if self.method == "cholesky":
898+
A = np_cholesky(cov)
899+
elif self.method == "svd":
900+
A, s, _ = np_svd(cov)
901+
A *= np_sqrt(s, out=s)[..., None, :]
902+
else:
903+
w, A = np_eigh(cov)
904+
A *= np_sqrt(w, out=w)[..., None, :]
905+
893906
out = rng.normal(size=(*size, mean.shape[-1]))
894907
np_einsum(
895908
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
896-
chol,
909+
A,
897910
out,
898-
out=out,
899911
optimize=False, # Nothing to optimize with two operands, skip costly setup
912+
out=out,
900913
)
901914
out += mean
902915
return out
903916

904917

905-
multivariate_normal = MvNormalRV()
918+
multivariate_normal = MvNormalRV(method="cholesky")
906919

907920

908921
class DirichletRV(RandomVariable):

tests/tensor/random/test_basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytensor.tensor import ones, stack
2020
from pytensor.tensor.random.basic import (
2121
ChoiceWithoutReplacement,
22+
MvNormalRV,
2223
PermutationRV,
2324
_gamma,
2425
bernoulli,
@@ -668,6 +669,36 @@ def test_mvnormal_ShapeFeature():
668669
assert s4.get_test_value() == 3
669670

670671

672+
@pytest.mark.parametrize("psd", (True, False))
673+
@pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
674+
def test_mvnormal_cov_decomposition_method(method, psd):
675+
mean = 2 ** np.arange(3)
676+
if psd:
677+
cov = [
678+
[1, 0.5, -1],
679+
[0.5, 2, 0],
680+
[-1, 0, 3],
681+
]
682+
else:
683+
cov = [
684+
[1, 0.5, 0],
685+
[0.5, 2, 0],
686+
[0, 0, 0],
687+
]
688+
rng = shared(np.random.default_rng(675))
689+
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
690+
assert draws.owner.op.method == method
691+
692+
if not psd and method == "cholesky":
693+
with pytest.raises(np.linalg.LinAlgError):
694+
draws.eval()
695+
return
696+
697+
draws_eval = draws.eval()
698+
np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.01)
699+
np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
700+
701+
671702
@pytest.mark.parametrize(
672703
"alphas, size",
673704
[

0 commit comments

Comments
 (0)