@@ -68,35 +68,38 @@ def perform(self, node, inputs, outputs):
         [x] = inputs
         [out] = outputs
 
+        (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))
+
         # Quick return for square empty array
         if x.size == 0:
-            eye = np.eye(1, dtype=x.dtype)
-            (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,))
-            c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True)
-            out[0] = np.empty_like(x, dtype=c.dtype)
+            out[0] = np.empty_like(x, dtype=potrf.dtype)
             return
 
-        x1 = np.asarray_chkfinite(x) if self.check_finite else x
+        if self.check_finite and not np.isfinite(x).all():
+            if self.on_error == "nan":
+                out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
+                return
+            else:
+                raise ValueError("array must not contain infs or NaNs")
 
         # Squareness check
-        if x1.shape[0] != x1.shape[1]:
+        if x.shape[0] != x.shape[1]:
             raise ValueError(
-                "Input array is expected to be square but has "
-                f"the shape: {x1.shape}."
+                "Input array is expected to be square but has " f"the shape: {x.shape}."
             )
 
         # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
         # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
-        if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
-            x1 = x1.T
+        c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
+        if c_contiguous_input:
+            x = x.T
             lower = not self.lower
             overwrite_a = True
         else:
             lower = self.lower
             overwrite_a = self.overwrite_a
 
-        (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,))
-        c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True)
+        c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)
 
         if info != 0:
             if self.on_error == "nan":
@@ -112,7 +115,7 @@ def perform(self, node, inputs, outputs):
                 )
         else:
             # Transpose result if input was transposed
-            out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c
+            out[0] = c.T if c_contiguous_input else c
 
     def L_op(self, inputs, outputs, gradients):
         """
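For context, a minimal standalone sketch (not part of this commit) of the F-contiguity trick the changed branch relies on: for a symmetric positive-definite matrix, the C-contiguous buffer can be handed to LAPACK as its F-contiguous transpose, factorized with the opposite triangle, and the result transposed back. The variable names below (`x`, `l_ref`, `u`) are illustrative only.

```python
# Sketch only: demonstrates why the diff transposes C-contiguous inputs and
# flips `lower` before calling potrf, instead of letting LAPACK copy the array.
import numpy as np
from scipy import linalg as scipy_linalg

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
x = a @ a.T + 4 * np.eye(4)  # symmetric positive definite, C-contiguous

(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))

# Reference lower factor; overwrite_a=False leaves x untouched.
l_ref, info = potrf(x, lower=True, overwrite_a=False, clean=True)
assert info == 0

# x.T is an F-contiguous view of the same memory, so overwrite_a=True lets
# LAPACK work in place; the upper factor of x.T, transposed, is the same
# lower factor (the Cholesky factor is unique for SPD matrices).
u, info = potrf(x.T, lower=False, overwrite_a=True, clean=True)
assert info == 0
np.testing.assert_allclose(u.T, l_ref)
```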