Skip to content

Commit 0d05944

Browse files
committed
TST: cov: add another test
1 parent ea98b18 commit 0d05944

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/array_api_extra/_funcs.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,28 +49,28 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4949
return x
5050

5151

52-
def cov(x: Array, *, xp: ModuleType) -> Array:
52+
def cov(m: Array, *, xp: ModuleType) -> Array:
5353
"""..."""
54-
x = xp.asarray(x, copy=True)
54+
m = xp.asarray(m, copy=True)
5555
dtype = (
56-
xp.float64 if xp.isdtype(x.dtype, "integral") else xp.result_type(x, xp.float64)
56+
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
5757
)
5858

59-
x = atleast_nd(x, ndim=2, xp=xp)
60-
x = xp.astype(x, dtype)
59+
m = atleast_nd(m, ndim=2, xp=xp)
60+
m = xp.astype(m, dtype)
6161

62-
avg = mean(x, axis=1, xp=xp)
63-
fact = x.shape[1] - 1
62+
avg = mean(m, axis=1, xp=xp)
63+
fact = m.shape[1] - 1
6464

6565
if fact <= 0:
6666
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
6767
fact = 0.0
6868

69-
x -= avg[:, None]
70-
y_transpose = x.T
71-
if xp.isdtype(y_transpose.dtype, "complex floating"):
72-
y_transpose = xp.conj(y_transpose)
73-
c = x @ y_transpose
69+
m -= avg[:, None]
70+
m_transpose = m.T
71+
if xp.isdtype(m_transpose.dtype, "complex floating"):
72+
m_transpose = xp.conj(m_transpose)
73+
c = m @ m_transpose
7474
c /= fact
7575
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
7676
return xp.squeeze(c, axis=axes)

tests/test_funcs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,12 @@ def test_empty(self):
9595
cov(xp.reshape(xp.asarray([]), (2, 0)), xp=xp),
9696
xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]),
9797
)
98+
99+
def test_combination(self):
100+
x = xp.asarray([-2.1, -1, 4.3])
101+
y = xp.asarray([3, 1.1, 0.12])
102+
X = xp.stack((x, y), axis=0)
103+
desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]])
104+
assert_allclose(cov(X, xp=xp), desired, rtol=1e-6)
105+
assert_allclose(cov(x, xp=xp), xp.asarray(11.71))
106+
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)

0 commit comments

Comments
 (0)