Skip to content

Commit 1a3bfd9

Browse files
juanitorduzricardoV94
authored andcommitted
Make PosDefMatrix Op output type boolean
1 parent 3fbe9a9 commit 1a3bfd9

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

pymc/distributions/multivariate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -849,9 +849,9 @@ def dist(cls, *args, **kwargs):
849849
def posdef(AA):
850850
try:
851851
linalg.cholesky(AA)
852-
return 1
852+
return True
853853
except linalg.LinAlgError:
854-
return 0
854+
return False
855855

856856

857857
class PosDefMatrix(Op):
@@ -868,15 +868,15 @@ class PosDefMatrix(Op):
868868
def make_node(self, x):
869869
x = pt.as_tensor_variable(x)
870870
assert x.ndim == 2
871-
o = TensorType(dtype="int8", shape=[])()
871+
o = TensorType(dtype="bool", shape=[])()
872872
return Apply(self, [x], [o])
873873

874874
# Python implementation:
875875
def perform(self, node, inputs, outputs):
876876
(x,) = inputs
877877
(z,) = outputs
878878
try:
879-
z[0] = np.array(posdef(x), dtype="int8")
879+
z[0] = np.array(posdef(x), dtype="bool")
880880
except Exception:
881881
pm._log.exception("Failed to check if %s positive definite", x)
882882
raise

tests/distributions/test_multivariate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,10 +2130,10 @@ def test_car_rng_fn(sparse):
21302130
@pytest.mark.parametrize(
21312131
"matrix, result",
21322132
[
2133-
([[1.0, 0], [0, 1]], 1),
2134-
([[1.0, 2], [2, 1]], 0),
2135-
([[1.0, 1], [1, 1]], 0),
2136-
([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], 0),
2133+
([[1.0, 0], [0, 1]], True),
2134+
([[1.0, 2], [2, 1]], False),
2135+
([[1.0, 1], [1, 1]], False),
2136+
([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], False),
21372137
],
21382138
)
21392139
def test_posdef_symmetric(matrix, result):

0 commit comments

Comments
 (0)