Skip to content

Commit 6f7acf5

Browse files
Blockwise PosDefMatrix check
1 parent 68d2ff2 commit 6f7acf5

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
get_underlying_scalar_constant_value,
3434
sigmoid,
3535
)
36+
from pytensor.tensor.blockwise import Blockwise
3637
from pytensor.tensor.elemwise import DimShuffle
3738
from pytensor.tensor.exceptions import NotScalarConstantError
3839
from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -923,18 +924,15 @@ def posdef(AA):
923924
class PosDefMatrix(Op):
924925
"""Check if input is positive definite. Input should be a square matrix."""
925926

926-
# Properties attribute
927927
__props__ = ()
928-
929-
# Compulsory if itypes and otypes are not defined
928+
gufunc_signature = "(m,m)->()"
930929

931930
def make_node(self, x):
932931
x = pt.as_tensor_variable(x)
933932
assert x.ndim == 2
934933
o = TensorType(dtype="bool", shape=[])()
935934
return Apply(self, [x], [o])
936935

937-
# Python implementation:
938936
def perform(self, node, inputs, outputs):
939937
(x,) = inputs
940938
(z,) = outputs
@@ -955,7 +953,7 @@ def __str__(self):
955953
return "MatrixIsPositiveDefinite"
956954

957955

958-
matrix_pos_def = PosDefMatrix()
956+
matrix_pos_def = Blockwise(PosDefMatrix())
959957

960958

961959
class WishartRV(RandomVariable):

0 commit comments

Comments
 (0)