1414
1515floatX = pytensor .config .floatX
1616
17- ATOL = 0 if floatX .endswith ("64" ) else 1e-6
18- RTOL = 1e-7 if floatX .endswith ("64" ) else 1e-6
17+ ATOL = 1e-8 if floatX .endswith ("64" ) else 1e-4
18+ RTOL = 1e-8 if floatX .endswith ("64" ) else 1e-4
1919rng = np .random .default_rng (42849 )
2020
2121
@@ -42,9 +42,7 @@ def transpose_func(x, trans):
4242@pytest .mark .filterwarnings (
4343 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
4444)
45- def test_solve_triangular (
46- b_func , b_size , lower , trans , unit_diag , complex , overwrite_b
47- ):
45+ def test_solve_triangular (b_func , b_size , lower , trans , unit_diag , complex ):
4846 if complex :
4947 # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
5048 # why?
@@ -62,11 +60,11 @@ def test_solve_triangular(
6260 f = pytensor .function ([A , b ], X , mode = "NUMBA" )
6361
6462 A_val = np .random .normal (size = (5 , 5 ))
65- b = np .random .normal (size = b_size )
63+ b_val = np .random .normal (size = b_size )
6664
6765 if complex :
6866 A_val = A_val + np .random .normal (size = (5 , 5 )) * 1j
69- b = b + np .random .normal (size = b_size ) * 1j
67+ b_val = b_val + np .random .normal (size = b_size ) * 1j
7068 A_sym = A_val @ A_val .conj ().T
7169
7270 A_tri = np .linalg .cholesky (A_sym ).astype (dtype )
@@ -76,19 +74,16 @@ def test_solve_triangular(
7674 A_tri = A_tri * adj_mat
7775
7876 A_tri = A_tri .astype (dtype )
79- b = b .astype (dtype )
77+ b_val = b_val .astype (dtype )
8078
8179 if not lower :
8280 A_tri = A_tri .T
8381
84- X_np = f (A_tri , b )
82+ X_np = f (A_tri , b_val )
8583 np .testing .assert_allclose (
86- transpose_func (A_tri , trans ) @ X_np , b , atol = ATOL , rtol = RTOL
84+ transpose_func (A_tri , trans ) @ X_np , b_val , atol = ATOL , rtol = RTOL
8785 )
8886
89- if overwrite_b :
90- assert_allclose (X_np , b )
91-
9287
9388@pytest .mark .parametrize ("value" , [np .nan , np .inf ])
9489@pytest .mark .filterwarnings (
@@ -100,11 +95,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
10095
10196 X = pt .linalg .solve_triangular (A , b , check_finite = True )
10297 f = pytensor .function ([A , b ], X , mode = "NUMBA" )
103- A_val = np .random .normal (size = (5 , 5 ))
98+ A_val = np .random .normal (size = (5 , 5 )). astype ( floatX )
10499 A_sym = A_val @ A_val .conj ().T
105100
106101 A_tri = np .linalg .cholesky (A_sym ).astype (floatX )
107- b = np .full ((5 , 1 ), value )
102+ b = np .full ((5 , 1 ), value ). astype ( floatX )
108103
109104 with pytest .raises (
110105 np .linalg .LinAlgError ,
@@ -126,8 +121,8 @@ def test_numba_Cholesky(lower, trans):
126121
127122 fg = FunctionGraph (outputs = [chol ])
128123
129- x = np .array ([0.1 , 0.2 , 0.3 ])
130- val = np .eye (3 ) + x [None , :] * x [:, None ]
124+ x = np .array ([0.1 , 0.2 , 0.3 ]). astype ( floatX )
125+ val = np .eye (3 ). astype ( floatX ) + x [None , :] * x [:, None ]
131126
132127 compare_numba_and_py (fg , [val ])
133128
@@ -385,4 +380,5 @@ def test_cho_solve(b_func, b_size, lower):
385380 b = b .astype (floatX )
386381
387382 X_np = f (A , b )
383+
388384 np .testing .assert_allclose (A @ X_np , b , atol = ATOL , rtol = RTOL )
0 commit comments