@@ -238,37 +238,46 @@ def test_infer_shape(self, b_shape):
238238 "b_size" , [(5 , 1 ), (5 , 5 ), (5 ,)], ids = ["b_col_vec" , "b_matrix" , "b_vec" ]
239239)
240240@pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
241- @pytest .mark .parametrize ("transposed" , [True , False ], ids = ["trans" , "no_trans" ])
242- def test_solve_correctness (b_size : tuple [int ], assume_a : str , transposed : bool ):
241+ def test_solve_correctness (b_size : tuple [int ], assume_a : str ):
243242 rng = np .random .default_rng (utt .fetch_seed ())
244243 A = pt .tensor ("A" , shape = (5 , 5 ))
245244 b = pt .tensor ("b" , shape = b_size )
246245
247- solve_op = functools .partial (
248- solve , assume_a = assume_a , transposed = transposed , b_ndim = len (b_size )
249- )
250- y = solve_op (
251- A ,
252- b ,
253- )
254- solve_func = pytensor .function ([A , b ], y )
255-
256- b_val = rng .normal (size = b_size ).astype (config .floatX )
257246 A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
247+ b_val = rng .normal (size = b_size ).astype (config .floatX )
248+
249+ solve_op = functools .partial (solve , assume_a = assume_a , b_ndim = len (b_size ))
258250
259- if assume_a == "sym" :
260- A_val = (A_val + A_val .T ) / 2
261- elif assume_a == "pos" :
262- A_val = A_val @ A_val .T
251+ def A_func (x ):
252+ if assume_a == "pos" :
253+ return x @ x .T
254+ elif assume_a == "sym" :
255+ return (x + x .T ) / 2
256+ else :
257+ return x
258+
259+ solve_input_val = A_func (A_val )
260+
261+ y = solve_op (A_func (A ), b )
262+ solve_func = pytensor .function ([A , b ], y )
263+ X_np = solve_func (A_val .copy (), b_val .copy ())
263264
264265 np .testing .assert_allclose (
265- scipy .linalg .solve (A_val , b_val , assume_a = assume_a , transposed = transposed ),
266- solve_func ( A_val , b_val ) ,
266+ scipy .linalg .solve (solve_input_val , b_val , assume_a = assume_a ),
267+ X_np ,
267268 )
268269
270+ np .testing .assert_allclose (A_func (A_val ) @ X_np , b_val , atol = 1e-6 )
271+
269272 eps = 2e-8 if config .floatX == "float64" else None
270273
271- utt .verify_grad (solve_op , [A_val , b_val ], 3 , rng , eps = eps )
274+ # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
275+ # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
276+ # the random perturbations used by verify_grad will result in invalid input matrices, and
277+ # LAPACK will silently do the wrong thing, making the gradients wrong
278+ utt .verify_grad (
279+ lambda A , b : solve_op (A_func (A ), b ), [A_val , b_val ], 3 , rng , eps = eps
280+ )
272281
273282
274283class TestSolveTriangular (utt .InferShapeTester ):
0 commit comments