77import numpy as np
88import scipy .linalg as scipy_linalg
99from numpy .exceptions import ComplexWarning
10+ from scipy .linalg ._misc import _datacopied
1011
1112import pytensor
1213import pytensor .tensor as pt
@@ -37,7 +38,7 @@ def __init__(
3738 self ,
3839 * ,
3940 lower : bool = True ,
40- check_finite : bool = True ,
41+ check_finite : bool = False ,
4142 on_error : Literal ["raise" , "nan" ] = "raise" ,
4243 overwrite_a : bool = False ,
4344 ):
@@ -64,21 +65,52 @@ def make_node(self, x):
6465 dtype = scipy_linalg .cholesky (np .eye (1 , dtype = x .type .dtype )).dtype
6566 return Apply (self , [x ], [tensor (shape = x .type .shape , dtype = dtype )])
6667
68+ def _cholesky (
69+ self , a , lower = False , overwrite_a = False , clean = True , check_finite = False
70+ ):
71+ a1 = np .asarray_chkfinite (a ) if check_finite else np .asarray (a )
72+
73+ # Squareness check
74+ if a1 .shape [0 ] != a1 .shape [1 ]:
75+ raise ValueError (
76+ "Input array is expected to be square but has "
77+ f"the shape: { a1 .shape } ."
78+ )
79+
80+ # Quick return for square empty array
81+ if a1 .size == 0 :
82+ dt = self ._cholesky (np .eye (1 , dtype = a1 .dtype )).dtype
83+ return np .empty_like (a1 , dtype = dt ), lower
84+
85+ overwrite_a = overwrite_a or _datacopied (a1 , a )
86+ (potrf ,) = scipy_linalg .get_lapack_funcs (("potrf" ,), (a1 ,))
87+ c , info = potrf (a1 , lower = lower , overwrite_a = overwrite_a , clean = clean )
88+ if info > 0 :
89+ raise scipy_linalg .LinAlgError (
90+ f"{ info } -th leading minor of the array is not positive definite"
91+ )
92+ if info < 0 :
93+ raise ValueError (
94+ f"LAPACK reported an illegal value in { - info } -th argument "
95+ f'on entry to "POTRF".'
96+ )
97+ return c
98+
6799 def perform (self , node , inputs , outputs ):
68100 [x ] = inputs
69101 [out ] = outputs
70102 try :
71103 # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
72104 # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
73105 if self .overwrite_a and x .flags ["C_CONTIGUOUS" ]:
74- out [0 ] = scipy_linalg . cholesky (
106+ out [0 ] = self . _cholesky (
75107 x .T ,
76108 lower = not self .lower ,
77109 check_finite = self .check_finite ,
78110 overwrite_a = True ,
79111 ).T
80112 else :
81- out [0 ] = scipy_linalg . cholesky (
113+ out [0 ] = self . _cholesky (
82114 x ,
83115 lower = self .lower ,
84116 check_finite = self .check_finite ,
@@ -201,7 +233,9 @@ def cholesky(
201233
202234 """
203235
204- return Blockwise (Cholesky (lower = lower , on_error = on_error ))(x )
236+ return Blockwise (
237+ Cholesky (lower = lower , on_error = on_error , check_finite = check_finite )
238+ )(x )
205239
206240
207241class SolveBase (Op ):
0 commit comments