18
18
19
19
import arviz as az
20
20
import jax
21
+ import jax .numpy as jnp
21
22
import numpy as np
22
23
import pytensor
23
24
import pytensor .tensor as pt
29
30
30
31
import pymc as pm
31
32
33
+ from pymc .distributions .multivariate import PosDefMatrix
32
34
from pymc .sampling .jax import (
33
35
_get_batched_jittered_initial_points ,
34
36
_get_log_likelihood ,
@@ -49,6 +51,25 @@ def test_old_import_route():
49
51
assert set (new_sj .__all__ ) <= set (dir (old_sj ))
50
52
51
53
54
+ def test_jax_PosDefMatrix ():
55
+ x = pt .tensor (name = "x" , shape = (2 , 2 ), dtype = "float32" )
56
+ matrix_pos_def = PosDefMatrix ()
57
+ x_is_pos_def = matrix_pos_def (x )
58
+ f = pytensor .function (inputs = [x ], outputs = [x_is_pos_def ], mode = "JAX" )
59
+
60
+ test_cases = [
61
+ (jnp .eye (2 ), True ),
62
+ (jnp .zeros (shape = (2 , 2 )), False ),
63
+ (jnp .array ([[1 , - 1.5 ], [0 , 1.2 ]], dtype = "float32" ), True ),
64
+ (- 1 * jnp .array ([[1 , - 1.5 ], [0 , 1.2 ]], dtype = "float32" ), False ),
65
+ (jnp .array ([[1 , - 1.5 ], [0 , - 1.2 ]], dtype = "float32" ), False ),
66
+ ]
67
+
68
+ for input , expected in test_cases :
69
+ actual = f (input )[0 ]
70
+ assert jnp .array_equal (a1 = actual , a2 = expected )
71
+
72
+
52
73
@pytest .mark .parametrize (
53
74
"sampler" ,
54
75
[
0 commit comments