Skip to content

Commit ea98b18

Browse files
committed
ENH: a simple cov
1 parent 6e596d9 commit ea98b18

File tree

3 files changed

+79
-5
lines changed

3 files changed

+79
-5
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd
3+
from ._funcs import atleast_nd, cov
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd"]
7+
__all__ = ["__version__", "atleast_nd", "cov"]

src/array_api_extra/_funcs.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import TYPE_CHECKING
45

56
if TYPE_CHECKING:
67
from ._typing import Array, ModuleType
78

8-
__all__ = ["atleast_nd"]
9+
__all__ = ["atleast_nd", "cov"]
910

1011

1112
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
@@ -46,3 +47,48 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4647
x = xp.expand_dims(x, axis=0)
4748
x = atleast_nd(x, ndim=ndim, xp=xp)
4849
return x
50+
51+
52+
def cov(x: Array, *, xp: ModuleType) -> Array:
53+
"""..."""
54+
x = xp.asarray(x, copy=True)
55+
dtype = (
56+
xp.float64 if xp.isdtype(x.dtype, "integral") else xp.result_type(x, xp.float64)
57+
)
58+
59+
x = atleast_nd(x, ndim=2, xp=xp)
60+
x = xp.astype(x, dtype)
61+
62+
avg = mean(x, axis=1, xp=xp)
63+
fact = x.shape[1] - 1
64+
65+
if fact <= 0:
66+
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
67+
fact = 0.0
68+
69+
x -= avg[:, None]
70+
y_transpose = x.T
71+
if xp.isdtype(y_transpose.dtype, "complex floating"):
72+
y_transpose = xp.conj(y_transpose)
73+
c = x @ y_transpose
74+
c /= fact
75+
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
76+
return xp.squeeze(c, axis=axes)
77+
78+
79+
def mean(
80+
x: Array,
81+
/,
82+
*,
83+
axis: int | tuple[int, ...] | None = None,
84+
keepdims: bool = False,
85+
xp: ModuleType,
86+
) -> Array:
87+
"""..."""
88+
if xp.isdtype(x.dtype, "complex floating"):
89+
x_real = xp.real(x)
90+
x_imag = xp.imag(x)
91+
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
92+
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
93+
return mean_real + (mean_imag * xp.asarray(1j))
94+
return xp.mean(x, axis=axis, keepdims=keepdims)

tests/test_funcs.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import warnings
4+
35
# array-api-strict#6
46
import array_api_strict as xp # type: ignore[import-untyped]
5-
from numpy.testing import assert_array_equal
7+
from numpy.testing import assert_allclose, assert_array_equal
68

7-
from array_api_extra import atleast_nd
9+
from array_api_extra import atleast_nd, cov
810

911

1012
class TestAtLeastND:
@@ -67,3 +69,29 @@ def test_5D(self):
6769

6870
y = atleast_nd(x, ndim=9, xp=xp)
6971
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
72+
73+
74+
class TestCov:
75+
def test_basic(self):
76+
assert_allclose(
77+
cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp),
78+
xp.asarray([[1.0, -1.0], [-1.0, 1.0]]),
79+
)
80+
81+
def test_complex(self):
82+
x = xp.asarray([[1, 2, 3], [1j, 2j, 3j]])
83+
res = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]])
84+
assert_allclose(cov(x, xp=xp), res)
85+
86+
def test_empty(self):
87+
with warnings.catch_warnings(record=True):
88+
warnings.simplefilter("always", RuntimeWarning)
89+
assert_array_equal(cov(xp.asarray([]), xp=xp), xp.nan)
90+
assert_array_equal(
91+
cov(xp.reshape(xp.asarray([]), (0, 2)), xp=xp),
92+
xp.reshape(xp.asarray([]), (0, 0)),
93+
)
94+
assert_array_equal(
95+
cov(xp.reshape(xp.asarray([]), (2, 0)), xp=xp),
96+
xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]),
97+
)

0 commit comments

Comments
 (0)