Skip to content

Commit 90fd8a3

Browse files
juanitorduzricardoV94
authored andcommitted
Add JAX implementation of PosDefMatrix Op
1 parent 1a3bfd9 commit 90fd8a3

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pymc/sampling/jax.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import arviz as az
2323
import jax
24+
import jax.numpy as jnp
2425
import numpy as np
2526
import pytensor.tensor as pt
2627

@@ -37,6 +38,7 @@
3738

3839
from pymc import Model, modelcontext
3940
from pymc.backends.arviz import find_constants, find_observations
41+
from pymc.distributions.multivariate import PosDefMatrix
4042
from pymc.initial_point import StartDict
4143
from pymc.logprob.utils import CheckParameterValue
4244
from pymc.sampling.mcmc import _init_jitter
@@ -72,6 +74,15 @@ def assert_fn(value, *inps):
7274
return assert_fn
7375

7476

77+
@jax_funcify.register(PosDefMatrix)
78+
def jax_funcify_PosDefMatrix(op, **kwargs):
79+
def posdefmatrix_fn(value, *inps):
80+
no_pos_def = jnp.any(jnp.isnan(jnp.linalg.cholesky(value)))
81+
return jnp.invert(no_pos_def)
82+
83+
return posdefmatrix_fn
84+
85+
7586
def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
7687
"""Replace shared variables in graph by their constant values
7788

tests/sampling/test_jax.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import arviz as az
2020
import jax
21+
import jax.numpy as jnp
2122
import numpy as np
2223
import pytensor
2324
import pytensor.tensor as pt
@@ -29,6 +30,7 @@
2930

3031
import pymc as pm
3132

33+
from pymc.distributions.multivariate import PosDefMatrix
3234
from pymc.sampling.jax import (
3335
_get_batched_jittered_initial_points,
3436
_get_log_likelihood,
@@ -49,6 +51,25 @@ def test_old_import_route():
4951
assert set(new_sj.__all__) <= set(dir(old_sj))
5052

5153

54+
def test_jax_PosDefMatrix():
55+
x = pt.tensor(name="x", shape=(2, 2), dtype="float32")
56+
matrix_pos_def = PosDefMatrix()
57+
x_is_pos_def = matrix_pos_def(x)
58+
f = pytensor.function(inputs=[x], outputs=[x_is_pos_def], mode="JAX")
59+
60+
test_cases = [
61+
(jnp.eye(2), True),
62+
(jnp.zeros(shape=(2, 2)), False),
63+
(jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), True),
64+
(-1 * jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), False),
65+
(jnp.array([[1, -1.5], [0, -1.2]], dtype="float32"), False),
66+
]
67+
68+
for input, expected in test_cases:
69+
actual = f(input)[0]
70+
assert jnp.array_equal(a1=actual, a2=expected)
71+
72+
5273
@pytest.mark.parametrize(
5374
"sampler",
5475
[

0 commit comments

Comments
 (0)