Skip to content

Commit c56004c

Browse files
author
Adrián García Pitarch
committed
ENH: cov delegation
1 parent ca20f03 commit c56004c

File tree

3 files changed

+84
-70
lines changed

3 files changed

+84
-70
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, nan_to_num, one_hot, pad
3+
from ._delegation import cov, isclose, nan_to_num, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
77
atleast_nd,
88
broadcast_shapes,
9-
cov,
109
create_diagonal,
1110
default_dtype,
1211
expand_dims,

src/array_api_extra/_delegation.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,87 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
21+
__all__ = ["cov", "isclose", "nan_to_num", "one_hot", "pad"]
22+
23+
24+
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
25+
"""
26+
Estimate a covariance matrix.
27+
28+
Covariance indicates the level to which two variables vary together.
29+
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
30+
then the covariance matrix element :math:`C_{ij}` is the covariance of
31+
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
32+
of :math:`x_i`.
33+
34+
This provides a subset of the functionality of ``numpy.cov``.
35+
36+
Parameters
37+
----------
38+
m : array
39+
A 1-D or 2-D array containing multiple variables and observations.
40+
Each row of `m` represents a variable, and each column a single
41+
observation of all those variables.
42+
xp : array_namespace, optional
43+
The standard-compatible namespace for `m`. Default: infer.
44+
45+
Returns
46+
-------
47+
array
48+
The covariance matrix of the variables.
49+
50+
Examples
51+
--------
52+
>>> import array_api_strict as xp
53+
>>> import array_api_extra as xpx
54+
55+
Consider two variables, :math:`x_0` and :math:`x_1`, which
56+
correlate perfectly, but in opposite directions:
57+
58+
>>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T
59+
>>> x
60+
Array([[0, 1, 2],
61+
[2, 1, 0]], dtype=array_api_strict.int64)
62+
63+
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
64+
matrix shows this clearly:
65+
66+
>>> xpx.cov(x, xp=xp)
67+
Array([[ 1., -1.],
68+
[-1., 1.]], dtype=array_api_strict.float64)
69+
70+
Note that element :math:`C_{0,1}`, which shows the correlation between
71+
:math:`x_0` and :math:`x_1`, is negative.
72+
73+
Further, note how `x` and `y` are combined:
74+
75+
>>> x = xp.asarray([-2.1, -1, 4.3])
76+
>>> y = xp.asarray([3, 1.1, 0.12])
77+
>>> X = xp.stack((x, y), axis=0)
78+
>>> xpx.cov(X, xp=xp)
79+
Array([[11.71 , -4.286 ],
80+
[-4.286 , 2.14413333]], dtype=array_api_strict.float64)
81+
82+
>>> xpx.cov(x, xp=xp)
83+
Array(11.71, dtype=array_api_strict.float64)
84+
85+
>>> xpx.cov(y, xp=xp)
86+
Array(2.14413333, dtype=array_api_strict.float64)
87+
"""
88+
89+
if xp is None:
90+
xp = array_namespace(m)
91+
92+
if (
93+
is_numpy_namespace(xp)
94+
or is_cupy_namespace(xp)
95+
or is_torch_namespace(xp)
96+
or is_dask_namespace(xp)
97+
or is_jax_namespace(xp)
98+
):
99+
return xp.cov(m)
100+
101+
return _funcs.cov(m, xp=xp)
22102

23103

24104
def isclose(

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -281,73 +281,8 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
281281
return tuple(out)
282282

283283

284-
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
285-
"""
286-
Estimate a covariance matrix.
287-
288-
Covariance indicates the level to which two variables vary together.
289-
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
290-
then the covariance matrix element :math:`C_{ij}` is the covariance of
291-
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
292-
of :math:`x_i`.
293-
294-
This provides a subset of the functionality of ``numpy.cov``.
295-
296-
Parameters
297-
----------
298-
m : array
299-
A 1-D or 2-D array containing multiple variables and observations.
300-
Each row of `m` represents a variable, and each column a single
301-
observation of all those variables.
302-
xp : array_namespace, optional
303-
The standard-compatible namespace for `m`. Default: infer.
304-
305-
Returns
306-
-------
307-
array
308-
The covariance matrix of the variables.
309-
310-
Examples
311-
--------
312-
>>> import array_api_strict as xp
313-
>>> import array_api_extra as xpx
314-
315-
Consider two variables, :math:`x_0` and :math:`x_1`, which
316-
correlate perfectly, but in opposite directions:
317-
318-
>>> x = xp.asarray([[0, 2], [1, 1], [2, 0]]).T
319-
>>> x
320-
Array([[0, 1, 2],
321-
[2, 1, 0]], dtype=array_api_strict.int64)
322-
323-
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
324-
matrix shows this clearly:
325-
326-
>>> xpx.cov(x, xp=xp)
327-
Array([[ 1., -1.],
328-
[-1., 1.]], dtype=array_api_strict.float64)
329-
330-
Note that element :math:`C_{0,1}`, which shows the correlation between
331-
:math:`x_0` and :math:`x_1`, is negative.
332-
333-
Further, note how `x` and `y` are combined:
334-
335-
>>> x = xp.asarray([-2.1, -1, 4.3])
336-
>>> y = xp.asarray([3, 1.1, 0.12])
337-
>>> X = xp.stack((x, y), axis=0)
338-
>>> xpx.cov(X, xp=xp)
339-
Array([[11.71 , -4.286 ],
340-
[-4.286 , 2.14413333]], dtype=array_api_strict.float64)
341-
342-
>>> xpx.cov(x, xp=xp)
343-
Array(11.71, dtype=array_api_strict.float64)
344-
345-
>>> xpx.cov(y, xp=xp)
346-
Array(2.14413333, dtype=array_api_strict.float64)
347-
"""
348-
if xp is None:
349-
xp = array_namespace(m)
350-
284+
def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
285+
"""See docstring in array_api_extra._delegation."""
351286
m = xp.asarray(m, copy=True)
352287
dtype = (
353288
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)

0 commit comments

Comments
 (0)