@@ -367,56 +367,48 @@ def test_infer_shape(self, b_shape):
367367 warn = False ,
368368 )
369369
370+ @pytest .mark .parametrize ("b_shape" , [(5 , 1 ), (5 ,)])
370371 @pytest .mark .parametrize ("lower" , [True , False ])
371- def test_correctness (self , lower ):
372+ @pytest .mark .parametrize ("trans" , ["N" , "T" ])
373+ def test_correctness (self , b_shape : tuple [int ], lower , trans ):
372374 rng = np .random .default_rng (utt .fetch_seed ())
373375
374- b_val = np .asarray (rng .random ((5 , 1 )), dtype = config .floatX )
375-
376+ b_val = np .asarray (rng .random (b_shape ), dtype = config .floatX )
376377 A_val = np .asarray (rng .random ((5 , 5 )), dtype = config .floatX )
377378 A_val = np .dot (A_val .transpose (), A_val )
378379
379380 C_val = scipy .linalg .cholesky (A_val , lower = lower )
380381
381382 A = matrix ()
382- b = matrix ( )
383+ b = pt . tensor ( "b" , shape = b_shape )
383384
384385 cholesky = Cholesky (lower = lower )
385386 C = cholesky (A )
386- y_lower = solve_triangular (C , b , lower = lower )
387+ y_lower = solve_triangular (C , b , lower = lower , trans = trans )
387388 lower_solve_func = pytensor .function ([C , b ], y_lower )
388389
389390 assert np .allclose (
390- scipy .linalg .solve_triangular (C_val , b_val , lower = lower ),
391+ scipy .linalg .solve_triangular (C_val , b_val , lower = lower , trans = trans ),
391392 lower_solve_func (C_val , b_val ),
392393 )
393394
394- @pytest .mark .parametrize (
395- "m, n, lower" ,
396- [
397- (5 , None , False ),
398- (5 , None , True ),
399- (4 , 2 , False ),
400- (4 , 2 , True ),
401- ],
402- )
403- def test_solve_grad (self , m , n , lower ):
395+ @pytest .mark .parametrize ("b_shape" , [(5 , 1 ), (5 ,)])
396+ @pytest .mark .parametrize ("lower" , [True , False ])
397+ @pytest .mark .parametrize ("trans" , ["N" , "T" ])
398+ def test_solve_grad (self , b_shape : tuple [int ], lower , trans ):
404399 rng = np .random .default_rng (utt .fetch_seed ())
400+ m = b_shape [0 ]
405401
406402 # Ensure diagonal elements of `A` are relatively large to avoid
407403 # numerical precision issues
408404 A_val = (rng .normal (size = (m , m )) * 0.5 + np .eye (m )).astype (config .floatX )
409-
410- if n is None :
411- b_val = rng .normal (size = m ).astype (config .floatX )
412- else :
413- b_val = rng .normal (size = (m , n )).astype (config .floatX )
405+ b_val = rng .normal (size = b_shape ).astype (config .floatX )
414406
415407 eps = None
416408 if config .floatX == "float64" :
417409 eps = 2e-8
418410
419- solve_op = SolveTriangular (lower = lower , b_ndim = 1 if n is None else 2 )
411+ solve_op = SolveTriangular (lower = lower , b_ndim = len ( b_shape ), trans = trans )
420412 utt .verify_grad (solve_op , [A_val , b_val ], 3 , rng , eps = eps )
421413
422414
0 commit comments