@@ -169,7 +169,12 @@ def test_eigvalsh_grad():
169169 )
170170
171171
172- class TestSolveBase (utt .InferShapeTester ):
172+ class TestSolveBase :
173+ class SolveTest (SolveBase ):
174+ def perform (self , node , inputs , outputs ):
175+ A , b = inputs
176+ outputs [0 ][0 ] = scipy .linalg .solve (A , b )
177+
173178 @pytest .mark .parametrize (
174179 "A_func, b_func, error_message" ,
175180 [
@@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message):
191196 with pytest .raises (ValueError , match = error_message ):
192197 A = A_func ()
193198 b = b_func ()
194- SolveBase (b_ndim = 2 )(A , b )
199+ self . SolveTest (b_ndim = 2 )(A , b )
195200
196201 def test__repr__ (self ):
197202 np .random .default_rng (utt .fetch_seed ())
198203 A = matrix ()
199204 b = matrix ()
200- y = SolveBase (b_ndim = 2 )(A , b )
205+ y = self . SolveTest (b_ndim = 2 )(A , b )
201206 assert (
202207 y .__repr__ ()
203- == "SolveBase {lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
208+ == "SolveTest {lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
204209 )
205210
206211
@@ -239,8 +244,9 @@ def test_correctness(self):
239244 A_val = np .asarray (rng .random ((5 , 5 )), dtype = config .floatX )
240245 A_val = np .dot (A_val .transpose (), A_val )
241246
242- assert np .allclose (
243- scipy .linalg .solve (A_val , b_val ), gen_solve_func (A_val , b_val )
247+ np .testing .assert_allclose (
248+ scipy .linalg .solve (A_val , b_val , assume_a = "gen" ),
249+ gen_solve_func (A_val , b_val ),
244250 )
245251
246252 A_undef = np .array (
@@ -253,7 +259,7 @@ def test_correctness(self):
253259 ],
254260 dtype = config .floatX ,
255261 )
256- assert np .allclose (
262+ np .testing . assert_allclose (
257263 scipy .linalg .solve (A_undef , b_val ), gen_solve_func (A_undef , b_val )
258264 )
259265
@@ -450,7 +456,7 @@ def test_solve_dtype(self):
450456 fn = function ([A , b ], x )
451457 x_result = fn (A_val .astype (A_dtype ), b_val .astype (b_dtype ))
452458
453- assert x .dtype == x_result .dtype
459+ assert x .dtype == x_result .dtype , ( A_dtype , b_dtype )
454460
455461
456462def test_cho_solve ():
0 commit comments