@@ -3415,43 +3415,27 @@ def test_tgsyl(dtype, trans, ijob):
3415
3415
err_msg = 'lhs2 and rhs2 do not match' )
3416
3416
3417
3417
3418
+ @pytest .mark .parametrize ('mtype' , ['sy' , 'he' ]) # matrix type
3418
3419
@pytest .mark .parametrize ('dtype' , DTYPES )
3419
3420
@pytest .mark .parametrize ('lower' , (0 , 1 ))
3420
- def test_sytrs (dtype , lower ):
3421
+ def test_sy_hetrs (mtype , dtype , lower ):
3422
+ if mtype == 'he' and dtype in REAL_DTYPES :
3423
+ pytest .skip ("hetrs not for real dtypes." )
3421
3424
rng = np .random .default_rng (1723059677121834 )
3422
3425
n , nrhs = 20 , 5
3423
3426
if dtype in COMPLEX_DTYPES :
3424
3427
A = (rng .uniform (size = (n , n )) + rng .uniform (size = (n , n ))* 1j ).astype (dtype )
3425
3428
else :
3426
3429
A = rng .uniform (size = (n , n )).astype (dtype )
3427
3430
3428
- A = A + A .T
3431
+ A = A + A .T if mtype == 'sy' else A + A . conj (). T
3429
3432
b = rng .uniform (size = (n , nrhs )).astype (dtype )
3430
- sytrf , sytrf_lwork , sytrs = get_lapack_funcs ([ 'sytrf ' , 'sytrf_lwork ' , 'sytrs' ],
3431
- dtype = dtype )
3432
- lwork = sytrf_lwork (n , lower = lower )
3433
- ldu , ipiv , info = sytrf (A , lwork = lwork )
3433
+ names = f' { mtype } trf ' , f' { mtype } trf_lwork ' , f' { mtype } trs'
3434
+ trf , trf_lwork , trs = get_lapack_funcs ( names , dtype = dtype )
3435
+ lwork = trf_lwork (n , lower = lower )
3436
+ ldu , ipiv , info = trf (A , lwork = lwork )
3434
3437
assert info == 0
3435
- x , info = sytrs (a = ldu , ipiv = ipiv , b = b )
3438
+ x , info = trs (a = ldu , ipiv = ipiv , b = b )
3436
3439
assert info == 0
3437
3440
eps = np .finfo (dtype ).eps
3438
3441
assert_allclose (A @x , b , atol = 100 * n * eps )
3439
-
3440
-
3441
- @pytest .mark .parametrize ('dtype' , COMPLEX_DTYPES )
3442
- @pytest .mark .parametrize ('lower' , (0 , 1 ))
3443
- def test_hetrs (dtype , lower ):
3444
- rng = np .random .default_rng (1723059677121834 )
3445
- n , nrhs = 20 , 5
3446
- A = (rng .uniform (size = (n , n )) + rng .uniform (size = (n , n ))* 1j ).astype (dtype )
3447
- A = A + A .conj ().T
3448
- b = np .random .rand (n , nrhs ).astype (dtype )
3449
- hetrf , hetrf_lwork , hetrs = get_lapack_funcs (['hetrf' , 'hetrf_lwork' , 'hetrs' ],
3450
- dtype = dtype )
3451
- lwork = hetrf_lwork (n , lower = lower )
3452
- ldu , ipiv , info = hetrf (A , lwork = lwork )
3453
- assert info == 0
3454
- x , info = hetrs (a = ldu , ipiv = ipiv , b = b )
3455
- assert info == 0
3456
- eps = np .finfo (dtype ).eps
3457
- assert_allclose (A @ x , b , atol = 100 * n * eps )
0 commit comments