@@ -413,10 +413,12 @@ def test_complex(self, xp: ModuleType):
413
413
expect = xp .asarray ([[1.0 , - 1.0j ], [1.0j , 1.0 ]], dtype = xp .complex128 )
414
414
xp_assert_close (actual , expect )
415
415
416
+ @pytest .mark .xfail_xp_backend (Backend .JAX , reason = "jax#32296" )
416
417
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "sparse#877" )
417
418
def test_empty (self , xp : ModuleType ):
418
419
with warnings .catch_warnings (record = True ):
419
420
warnings .simplefilter ("always" , RuntimeWarning )
421
+ warnings .simplefilter ("always" , UserWarning )
420
422
xp_assert_equal (cov (xp .asarray ([], dtype = xp .float64 )), xp .asarray (xp .nan , dtype = xp .float64 ))
421
423
xp_assert_equal (
422
424
cov (xp .reshape (xp .asarray ([], dtype = xp .float64 ), (0 , 2 ))),
@@ -436,6 +438,7 @@ def test_combination(self, xp: ModuleType):
436
438
xp_assert_close (cov (x ), xp .asarray (11.71 , dtype = xp .float64 ))
437
439
xp_assert_close (cov (y ), xp .asarray (2.144133 , dtype = xp .float64 ), rtol = 1e-6 )
438
440
441
+ @pytest .mark .xfail_xp_backend (Backend .TORCH , reason = "array-api-extra#455" )
439
442
def test_device (self , xp : ModuleType , device : Device ):
440
443
x = xp .asarray ([1 , 2 , 3 ], device = device )
441
444
assert get_device (cov (x )) == device
0 commit comments