Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
@jax_sample_fn.register(ptr.BetaRV)
@jax_sample_fn.register(ptr.DirichletRV)
@jax_sample_fn.register(ptr.PoissonRV)
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_fn_generic(op, node):
"""Generic JAX implementation of random variables."""
name = op.name
Expand Down Expand Up @@ -173,6 +172,20 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_mvnormal(op, node):
def sample_fn(rng, size, dtype, mean, cov):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.multivariate_normal(
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op, node):
"""JAX implementation of `BernoulliRV`."""
Expand Down
19 changes: 16 additions & 3 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,24 @@

@numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node):
method = op.method

@numba_basic.numba_njit
def random_fn(rng, mean, cov):
chol = np.linalg.cholesky(cov)
stdnorm = rng.normal(size=cov.shape[-1])
return np.dot(chol, stdnorm) + mean
if method == "cholesky":
A = np.linalg.cholesky(cov)

Check warning on line 152 in pytensor/link/numba/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/random.py#L152

Added line #L152 was not covered by tests
elif method == "svd":
A, s, _ = np.linalg.svd(cov)
A *= np.sqrt(s)[None, :]

Check warning on line 155 in pytensor/link/numba/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/random.py#L154-L155

Added lines #L154 - L155 were not covered by tests
else:
w, A = np.linalg.eigh(cov)
A *= np.sqrt(w)[None, :]

Check warning on line 158 in pytensor/link/numba/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/random.py#L157-L158

Added lines #L157 - L158 were not covered by tests

out = rng.normal(size=cov.shape[-1])

Check warning on line 160 in pytensor/link/numba/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/random.py#L160

Added line #L160 was not covered by tests
# out argument not working correctly: https://github.com/numba/numba/issues/9924
out[:] = np.dot(A, out)
out += mean
return out

Check warning on line 164 in pytensor/link/numba/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/random.py#L162-L164

Added lines #L162 - L164 were not covered by tests

random_fn.handles_out = True
return random_fn
Expand Down
88 changes: 39 additions & 49 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import abc
import warnings
from typing import Literal

import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
from numpy.linalg import cholesky as np_cholesky
from numpy.linalg import eigh as np_eigh
from numpy.linalg import svd as np_svd

import pytensor
from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt
Expand Down Expand Up @@ -831,27 +837,6 @@
vonmises = VonMisesRV()


def safe_multivariate_normal(mean, cov, size=None, rng=None):
"""A shape consistent multivariate normal sampler.
What we mean by "shape consistent": SciPy will return scalars when the
arguments are vectors with dimension of size 1. We require that the output
be at least 1D, so that it's consistent with the underlying random
variable.
"""
res = np.atleast_1d(
stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs(
size=size, random_state=rng
)
)

if size is not None:
res = res.reshape([*size, -1])

return res


class MvNormalRV(RandomVariable):
r"""A multivariate normal random variable.
Expand All @@ -870,8 +855,17 @@
signature = "(n),(n,n)->(n)"
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
__props__ = ("name", "signature", "dtype", "inplace", "method")

def __call__(self, mean=None, cov=None, size=None, **kwargs):
def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
super().__init__(*args, **kwargs)
if method not in ("cholesky", "svd", "eigh"):
raise ValueError(

Check warning on line 863 in pytensor/tensor/random/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/random/basic.py#L863

Added line #L863 was not covered by tests
f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
)
self.method = method

def __call__(self, mean, cov, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.
Signature
Expand All @@ -894,38 +888,34 @@
is specified, a single `N`-dimensional sample is returned.
"""
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype

if mean is None:
mean = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
Comment on lines -899 to -902
Copy link
Member Author

@ricardoV94 ricardoV94 Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were dumb defaults, just removed them: #833

It made sense when both were None perhaps, but not just one of them. Anyway, numpy doesn't provide defaults either. In PyMC we do, because we are not trying to mimick numpy API there.

return super().__call__(mean, cov, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, mean, cov, size):
if mean.ndim > 1 or cov.ndim > 2:
# Neither SciPy nor NumPy implement parameter broadcasting for
# multivariate normals (or any other multivariate distributions),
# so we need to implement that here
def rng_fn(self, rng, mean, cov, size):
if size is None:
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])

if size is None:
mean, cov = broadcast_params([mean, cov], [1, 2])
else:
mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:])

res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]):
m = mean[idx]
c = cov[idx]
res[idx] = safe_multivariate_normal(m, c, rng=rng)
return res
if self.method == "cholesky":
A = np_cholesky(cov)
elif self.method == "svd":
A, s, _ = np_svd(cov)
A *= np_sqrt(s, out=s)[..., None, :]
else:
return safe_multivariate_normal(mean, cov, size=size, rng=rng)
w, A = np_eigh(cov)
A *= np_sqrt(w, out=w)[..., None, :]

out = rng.normal(size=(*size, mean.shape[-1]))
np_einsum(
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
A,
out,
optimize=False, # Nothing to optimize with two operands, skip costly setup
out=out,
)
out += mean
return out


multivariate_normal = MvNormalRV()
multivariate_normal = MvNormalRV(method="cholesky")


class DirichletRV(RandomVariable):
Expand Down
6 changes: 6 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
)


Expand Down Expand Up @@ -547,6 +548,11 @@ def test_random_mvnormal():
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)


test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"JAX"
)


@pytest.mark.parametrize(
"parameter, size",
[
Expand Down
6 changes: 6 additions & 0 deletions tests/link/numba/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
)


Expand Down Expand Up @@ -147,6 +148,11 @@ def test_multivariate_normal():
)


test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"NUMBA"
)


@pytest.mark.parametrize(
"rv_op, dist_args, size",
[
Expand Down
6 changes: 4 additions & 2 deletions tests/tensor/random/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,10 @@ def rand_bool_mask(shape, rng=None):
multivariate_normal,
(
np.array([200, 250], dtype=config.floatX),
# Second covariance is invalid, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
# Second covariance is very large, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 1000, np.eye(2)]).T.astype(
config.floatX
)
* 1e-6,
),
(3,),
Expand Down
82 changes: 72 additions & 10 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import (
ChoiceWithoutReplacement,
MvNormalRV,
PermutationRV,
_gamma,
bernoulli,
Expand Down Expand Up @@ -521,13 +522,19 @@ def test_fn(shape, scale, **kwargs):


def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
if mean is None:
mean = np.array([0.0], dtype=config.floatX)
if cov is None:
cov = np.array([[1.0]], dtype=config.floatX)
if size is not None:
size = tuple(size)
return multivariate_normal.rng_fn(random_state, mean, cov, size)
rng = random_state if random_state is not None else np.random.default_rng()

if size is None:
size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])

mean = np.broadcast_to(mean, (*size, *mean.shape[-1:]))
cov = np.broadcast_to(cov, (*size, *cov.shape[-2:]))

@np.vectorize(signature="(n),(n,n)->(n)")
def vec_mvnormal(mean, cov):
return rng.multivariate_normal(mean, cov, method="cholesky")

return vec_mvnormal(mean, cov)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -609,18 +616,30 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
),
],
)
@pytest.mark.skipif(
config.floatX == "float32",
reason="Draws are only strictly equal to numpy in float64",
)
def test_mvnormal_samples(mu, cov, size):
compare_sample_values(
multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
)


def test_mvnormal_default_args():
compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
def test_mvnormal_no_default_args():
with pytest.raises(
TypeError, match="missing 2 required positional arguments: 'mean' and 'cov'"
):
multivariate_normal()


def test_mvnormal_impl_catches_incompatible_size():
with pytest.raises(ValueError, match="operands could not be broadcast together "):
multivariate_normal.rng_fn(
None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,)
np.random.default_rng(),
np.zeros((3, 2)),
np.broadcast_to(np.eye(2), (3, 2, 2)),
size=(4,),
)


Expand Down Expand Up @@ -668,6 +687,49 @@ def test_mvnormal_ShapeFeature():
assert s4.get_test_value() == 3


def create_mvnormal_cov_decomposition_method_test(mode):
@pytest.mark.parametrize("psd", (True, False))
@pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
def test_mvnormal_cov_decomposition_method(method, psd):
mean = 2 ** np.arange(3)
if psd:
cov = [
[1, 0.5, -1],
[0.5, 2, 0],
[-1, 0, 3],
]
else:
cov = [
[1, 0.5, 0],
[0.5, 2, 0],
[0, 0, 0],
]
rng = shared(np.random.default_rng(675))
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
assert draws.owner.op.method == method

# JAX doesn't raise errors at runtime
if not psd and method == "cholesky":
if mode == "JAX":
# JAX doesn't raise errors at runtime, instead it returns nan
np.isnan(draws.eval(mode=mode)).all()
else:
with pytest.raises(np.linalg.LinAlgError):
draws.eval(mode=mode)

else:
draws_eval = draws.eval(mode=mode)
np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)

return test_mvnormal_cov_decomposition_method


test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
None
)


@pytest.mark.parametrize(
"alphas, size",
[
Expand Down