Skip to content

Commit 5fc49f5

Browse files
author
Adrián García Pitarch
committed
use explicit dtypes for cov tests
1 parent 88224c3 commit 5fc49f5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/test_funcs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,32 +404,32 @@ def test_none(self, args: tuple[tuple[float | None, ...], ...]):
404404
class TestCov:
405405
def test_basic(self, xp: ModuleType):
406406
xp_assert_close(
407-
cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T),
407+
cov(xp.asarray([[0, 2], [1, 1], [2, 0]], dtype=xp.float64).T),
408408
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
409409
)
410410

411411
def test_complex(self, xp: ModuleType):
412-
actual = cov(xp.asarray([[1, 2, 3], [1j, 2j, 3j]]))
412+
actual = cov(xp.asarray([[1, 2, 3], [1j, 2j, 3j]], dtype=xp.complex128))
413413
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
414414
xp_assert_close(actual, expect)
415415

416416
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877")
417417
def test_empty(self, xp: ModuleType):
418418
with warnings.catch_warnings(record=True):
419419
warnings.simplefilter("always", RuntimeWarning)
420-
xp_assert_equal(cov(xp.asarray([])), xp.asarray(xp.nan, dtype=xp.float64))
420+
xp_assert_equal(cov(xp.asarray([], dtype=xp.float64)), xp.asarray(xp.nan, dtype=xp.float64))
421421
xp_assert_equal(
422-
cov(xp.reshape(xp.asarray([]), (0, 2))),
422+
cov(xp.reshape(xp.asarray([], dtype=xp.float64), (0, 2))),
423423
xp.reshape(xp.asarray([], dtype=xp.float64), (0, 0)),
424424
)
425425
xp_assert_equal(
426-
cov(xp.reshape(xp.asarray([]), (2, 0))),
426+
cov(xp.reshape(xp.asarray([], dtype=xp.float64), (2, 0))),
427427
xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]], dtype=xp.float64),
428428
)
429429

430430
def test_combination(self, xp: ModuleType):
431-
x = xp.asarray([-2.1, -1, 4.3])
432-
y = xp.asarray([3, 1.1, 0.12])
431+
x = xp.asarray([-2.1, -1, 4.3], dtype=xp.float64)
432+
y = xp.asarray([3, 1.1, 0.12], dtype=xp.float64)
433433
X = xp.stack((x, y), axis=0)
434434
desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]], dtype=xp.float64)
435435
xp_assert_close(cov(X), desired, rtol=1e-6)
@@ -443,7 +443,7 @@ def test_device(self, xp: ModuleType, device: Device):
443443
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
444444
def test_xp(self, xp: ModuleType):
445445
xp_assert_close(
446-
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
446+
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]], dtype=xp.float64).T, xp=xp),
447447
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
448448
)
449449

0 commit comments

Comments
 (0)