@@ -1479,6 +1479,8 @@ def testLuOnZeroMatrix(self, lu):
14791479 dtype = float_types + complex_types ,
14801480 )
14811481 def testLuGrad (self , shape , dtype ):
1482+ if jtu .is_device_rocm ():
1483+ self .skipTest ("Skip on ROCm: testLuGrad. Test aborts due to HIP runtime issue" )
14821484 rng = jtu .rand_default (self .rng ())
14831485 a = rng (shape , dtype )
14841486 lu = vmap (jsp .linalg .lu ) if len (shape ) > 2 else jsp .linalg .lu
@@ -1709,6 +1711,9 @@ def testTriangularSolveSingularBatched(self):
17091711 dtype = int_types + float_types + complex_types
17101712 )
17111713 def testExpm (self , n , batch_size , dtype ):
1714+ if jtu .is_device_rocm ():
1715+ self .skipTest ("Skip on ROCm: testExpm. Test aborts due to HIP runtime issue" )
1716+
17121717 if (jtu .test_device_matches (["cuda" ]) and
17131718 _is_required_cuda_version_satisfied (12000 )):
17141719 self .skipTest ("Triggers a bug in cuda-12 b/287345077" )
@@ -1861,6 +1866,8 @@ def sp_func(a):
18611866 dtype = float_types + complex_types ,
18621867 )
18631868 def testIssue2131 (self , n , dtype ):
1869+ if jtu .is_device_rocm ():
1870+ self .skipTest ("Skip on ROCm: testIssue2131. Test aborts due to HIP runtime issue" )
18641871 args_maker_zeros = lambda : [np .zeros ((n , n ), dtype )]
18651872 osp_fun = lambda a : osp .linalg .expm (a )
18661873 jsp_fun = lambda a : jsp .linalg .expm (a )
@@ -1896,6 +1903,8 @@ def args_maker():
18961903 dtype = float_types + complex_types ,
18971904 )
18981905 def testExpmFrechet (self , n , dtype ):
1906+ if jtu .is_device_rocm ():
1907+ self .skipTest ("Skip on ROCm: testExpmFrechet. Test aborts due to HIP runtime issue" )
18991908 rng = jtu .rand_small (self .rng ())
19001909 if dtype == np .float64 or dtype == np .complex128 :
19011910 target_norms = [1.0e-2 , 2.0e-1 , 9.0e-01 , 2.0 , 3.0 ]
@@ -1934,6 +1943,9 @@ def args_maker():
19341943 dtype = float_types + complex_types ,
19351944 )
19361945 def testExpmGrad (self , n , dtype ):
1946+ if jtu .is_device_rocm ():
1947+ self .skipTest ("Skip on ROCm: testExpmGrad. Test aborts due to HIP runtime issue" )
1948+
19371949 rng = jtu .rand_small (self .rng ())
19381950 a = rng ((n , n ), dtype )
19391951 if dtype == np .float64 or dtype == np .complex128 :
0 commit comments