@@ -214,6 +214,22 @@ def test_solve_raises_on_invalid_A():
214214 Solve (assume_a = "test" , b_ndim = 2 )
215215
216216
217+ solve_test_cases = [
218+ ("gen" , False , False ),
219+ ("gen" , False , True ),
220+ ("sym" , False , False ),
221+ ("sym" , True , False ),
222+ ("sym" , True , True ),
223+ ("pos" , False , False ),
224+ ("pos" , True , False ),
225+ ("pos" , True , True ),
226+ ]
227+ solve_test_ids = [
228+ f'{ assume_a } _{ "lower" if lower else "upper" } _{ "A^T" if transposed else "A" } '
229+ for assume_a , lower , transposed in solve_test_cases
230+ ]
231+
232+
217233class TestSolve (utt .InferShapeTester ):
218234 @pytest .mark .parametrize ("b_shape" , [(5 , 1 ), (5 ,)])
219235 def test_infer_shape (self , b_shape ):
@@ -235,16 +251,26 @@ def test_infer_shape(self, b_shape):
235251 @pytest .mark .parametrize (
236252 "b_size" , [(5 , 1 ), (5 , 5 ), (5 ,)], ids = ["b_col_vec" , "b_matrix" , "b_vec" ]
237253 )
238- @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
239- def test_solve_correctness (self , b_size : tuple [int ], assume_a : str ):
254+ @pytest .mark .parametrize (
255+ "assume_a, lower, transposed" , solve_test_cases , ids = solve_test_ids
256+ )
257+ def test_solve_correctness (
258+ self , b_size : tuple [int ], assume_a : str , lower : bool , transposed : bool
259+ ):
240260 rng = np .random .default_rng (utt .fetch_seed ())
241261 A = pt .tensor ("A" , shape = (5 , 5 ))
242262 b = pt .tensor ("b" , shape = b_size )
243263
244264 A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
245265 b_val = rng .normal (size = b_size ).astype (config .floatX )
246266
247- solve_op = functools .partial (solve , assume_a = assume_a , b_ndim = len (b_size ))
267+ solve_op = functools .partial (
268+ solve ,
269+ assume_a = assume_a ,
270+ lower = lower ,
271+ transposed = transposed ,
272+ b_ndim = len (b_size ),
273+ )
248274
249275 def A_func (x ):
250276 if assume_a == "pos" :
@@ -254,6 +280,11 @@ def A_func(x):
254280 else :
255281 return x
256282
283+ def T (x ):
284+ if transposed :
285+ return x .T
286+ return x
287+
257288 solve_input_val = A_func (A_val )
258289
259290 y = solve_op (A_func (A ), b )
@@ -264,30 +295,27 @@ def A_func(x):
264295 RTOL = 1e-8 if config .floatX .endswith ("64" ) else 1e-4
265296
266297 np .testing .assert_allclose (
267- scipy .linalg .solve (solve_input_val , b_val , assume_a = assume_a ),
298+ scipy .linalg .solve (
299+ solve_input_val ,
300+ b_val ,
301+ assume_a = assume_a ,
302+ transposed = transposed ,
303+ lower = lower ,
304+ ),
268305 X_np ,
269306 atol = ATOL ,
270307 rtol = RTOL ,
271308 )
272309
273- np .testing .assert_allclose (A_func (A_val ) @ X_np , b_val , atol = ATOL , rtol = RTOL )
310+ np .testing .assert_allclose (T ( A_func (A_val ) ) @ X_np , b_val , atol = ATOL , rtol = RTOL )
274311
275312 @pytest .mark .parametrize (
276313 "b_size" , [(5 , 1 ), (5 , 5 ), (5 ,)], ids = ["b_col_vec" , "b_matrix" , "b_vec" ]
277314 )
278315 @pytest .mark .parametrize (
279316 "assume_a, lower, transposed" ,
280- [
281- ("gen" , False , False ),
282- ("gen" , False , True ),
283- ("sym" , False , False ),
284- ("sym" , True , False ),
285- ("sym" , True , True ),
286- ("pos" , False , False ),
287- ("pos" , True , False ),
288- ("pos" , True , True ),
289- ],
290- ids = str ,
317+ solve_test_cases ,
318+ ids = solve_test_ids ,
291319 )
292320 @pytest .mark .skipif (
293321 config .floatX == "float32" , reason = "Gradients not numerically stable in float32"
0 commit comments