|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import warnings |
| 4 | + |
3 | 5 | # array-api-strict#6 |
4 | 6 | import array_api_strict as xp # type: ignore[import-untyped] |
5 | | -from numpy.testing import assert_array_equal |
| 7 | +from numpy.testing import assert_allclose, assert_array_equal |
6 | 8 |
|
7 | | -from array_api_extra import atleast_nd |
| 9 | +from array_api_extra import atleast_nd, cov |
8 | 10 |
|
9 | 11 |
|
10 | 12 | class TestAtLeastND: |
@@ -67,3 +69,29 @@ def test_5D(self): |
67 | 69 |
|
68 | 70 | y = atleast_nd(x, ndim=9, xp=xp) |
69 | 71 | assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) |
| 72 | + |
| 73 | + |
| 74 | +class TestCov: |
| 75 | + def test_basic(self): |
| 76 | + assert_allclose( |
| 77 | + cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp), |
| 78 | + xp.asarray([[1.0, -1.0], [-1.0, 1.0]]), |
| 79 | + ) |
| 80 | + |
| 81 | + def test_complex(self): |
| 82 | + x = xp.asarray([[1, 2, 3], [1j, 2j, 3j]]) |
| 83 | + res = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]]) |
| 84 | + assert_allclose(cov(x, xp=xp), res) |
| 85 | + |
| 86 | + def test_empty(self): |
| 87 | + with warnings.catch_warnings(record=True): |
| 88 | + warnings.simplefilter("always", RuntimeWarning) |
| 89 | + assert_array_equal(cov(xp.asarray([]), xp=xp), xp.nan) |
| 90 | + assert_array_equal( |
| 91 | + cov(xp.reshape(xp.asarray([]), (0, 2)), xp=xp), |
| 92 | + xp.reshape(xp.asarray([]), (0, 0)), |
| 93 | + ) |
| 94 | + assert_array_equal( |
| 95 | + cov(xp.reshape(xp.asarray([]), (2, 0)), xp=xp), |
| 96 | + xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]), |
| 97 | + ) |
0 commit comments