@@ -88,7 +88,7 @@ class TestSolves:
8888 [(5 , 1 ), (5 , 5 ), (5 ,)],
8989 ids = ["b_col_vec" , "b_matrix" , "b_vec" ],
9090 )
91- @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
91+ @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ][ 1 :: 2 ] , ids = str )
9292 @pytest .mark .parametrize ("lower" , [True , False ], ids = lambda x : f"lower={ x } " )
9393 @pytest .mark .parametrize (
9494 "overwrite_a, overwrite_b" ,
@@ -169,7 +169,11 @@ def A_func(x):
169169 b_val_c_contig = np .copy (b_val , order = "C" )
170170 res_c_contig = f (A_val_c_contig , b_val_c_contig )
171171 np .testing .assert_allclose (res_c_contig , res )
172- np .testing .assert_allclose (A_val_c_contig , A_val )
172+ # In the symmetric and positive definite cases,
173+ # we can only destroy A C-contiguous arrays by inverting `lower` at runtime
174+ assert np .allclose (A_val_c_contig , A_val ) == (
175+ not (overwrite_a and assume_a in ("sym" , "pos" ))
176+ )
173177 # b vectors are always f_contiguous if also c_contiguous
174178 assert np .allclose (b_val_c_contig , b_val ) == (
175179 not (overwrite_b and b_val_c_contig .flags .f_contiguous )
@@ -273,7 +277,7 @@ def A_func(x):
273277 b_val_c_contig = np .copy (b_val , order = "C" )
274278 res_c_contig = f (A_val_c_contig , b_val_c_contig )
275279 np .testing .assert_allclose (res_c_contig , res )
276- np .testing . assert_allclose (A_val_c_contig , A_val )
280+ assert np .allclose (A_val_c_contig , A_val )
277281 # b c_contiguous vectors are also f_contiguous and destroyable
278282 assert np .allclose (b_val_c_contig , b_val ) == (
279283 not (overwrite_b and b_val_c_contig .flags .f_contiguous )
@@ -359,7 +363,7 @@ def A_func(x):
359363 res_f_contig = f (A_val_f_contig , b_val_f_contig )
360364 np .testing .assert_allclose (res_f_contig , res )
361365 # cho_solve never destroys A
362- np .testing .assert_allclose (A_val , A_val_f_contig )
366+ np .testing .assert_allclose (A_val == A_val_f_contig )
363367 # b Should always be destroyable
364368 assert (b_val == b_val_f_contig ).all () == (not overwrite_b )
365369
@@ -368,7 +372,7 @@ def A_func(x):
368372 b_val_c_contig = np .copy (b_val , order = "C" )
369373 res_c_contig = f (A_val_c_contig , b_val_c_contig )
370374 np .testing .assert_allclose (res_c_contig , res )
371- np .testing . assert_allclose (A_val_c_contig , A_val )
375+ assert np .allclose (A_val_c_contig , A_val )
372376 # b c_contiguous vectors are also f_contiguous and destroyable
373377 assert np .allclose (b_val_c_contig , b_val ) == (
374378 not (overwrite_b and b_val_c_contig .flags .f_contiguous )
0 commit comments