Skip to content

Commit 28e6f86

Browse files
committed
DOC: cov, mean: WIP docstrings
1 parent 0d05944 commit 28e6f86

File tree

1 file changed

+99
-3
lines changed

1 file changed

+99
-3
lines changed

src/array_api_extra/_funcs.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,65 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4949
return x
5050

5151

52-
def cov(m: Array, *, xp: ModuleType) -> Array:
53-
"""..."""
52+
def cov(m: Array, /, *, xp: ModuleType) -> Array:
53+
"""
54+
Estimate a covariance matrix, given data and weights.
55+
56+
Covariance indicates the level to which two variables vary together.
57+
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
58+
then the covariance matrix element :math:`C_{ij}` is the covariance of
59+
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
60+
of :math:`x_i`.
61+
62+
This provides a subset of the functionality of ``numpy.cov``.
63+
64+
Parameters
65+
----------
66+
m : array
67+
A 1-D or 2-D array containing multiple variables and observations.
68+
Each row of `m` represents a variable, and each column a single
69+
observation of all those variables.
70+
71+
Returns
72+
-------
73+
res : array
74+
The covariance matrix of the variables.
75+
76+
Examples
77+
--------
78+
>>> import numpy as np
79+
80+
Consider two variables, :math:`x_0` and :math:`x_1`, which
81+
correlate perfectly, but in opposite directions:
82+
83+
>>> x = np.array([[0, 2], [1, 1], [2, 0]]).T
84+
>>> x
85+
array([[0, 1, 2],
86+
[2, 1, 0]])
87+
88+
Note how :math:`x_0` increases while :math:`x_1` decreases. The covariance
89+
matrix shows this clearly:
90+
91+
>>> np.cov(x)
92+
array([[ 1., -1.],
93+
[-1., 1.]])
94+
95+
Note that element :math:`C_{0,1}`, which shows the correlation between
96+
:math:`x_0` and :math:`x_1`, is negative.
97+
98+
Further, note how `x` and `y` are combined:
99+
100+
>>> x = [-2.1, -1, 4.3]
101+
>>> y = [3, 1.1, 0.12]
102+
>>> X = np.stack((x, y), axis=0)
103+
>>> np.cov(X)
104+
array([[11.71 , -4.286 ], # may vary
105+
[-4.286 , 2.144133]])
106+
>>> np.cov(x)
107+
array(11.71)
108+
>>> xp.cov(y)
109+
110+
"""
54111
m = xp.asarray(m, copy=True)
55112
dtype = (
56113
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
@@ -84,7 +141,46 @@ def mean(
84141
keepdims: bool = False,
85142
xp: ModuleType,
86143
) -> Array:
87-
"""..."""
144+
"""
145+
Calculates the arithmetic mean of the input array ``x``.
146+
147+
In addition to the standard ``mean``, this function supports complex-valued input.
148+
149+
Parameters
150+
----------
151+
x: array
152+
input array. Should have a floating-point data type.
153+
axis: int or tuple of ints, optional
154+
axis or axes along which arithmetic means must be computed.
155+
By default, the mean must be computed over the entire array.
156+
If a tuple of integers, arithmetic means must be computed over multiple axes.
157+
Default: ``None``.
158+
keepdims: bool, optional
159+
if ``True``, the reduced axes (dimensions) must be included in the result as
160+
singleton dimensions, and, accordingly, the result must be compatible with
161+
the input array (see :ref:`broadcasting`).
162+
Otherwise, if ``False``, the reduced axes (dimensions) must not be included
163+
in the result. Default: ``False``.
164+
165+
Returns
166+
-------
167+
out: array
168+
if the arithmetic mean was computed over the entire array,
169+
a zero-dimensional array containing the arithmetic mean;
170+
otherwise, a non-zero-dimensional array containing the arithmetic means.
171+
The returned array must have the same data type as ``x``.
172+
173+
Notes
174+
-----
175+
176+
**Special Cases**
177+
178+
Let ``N`` equal the number of elements over which to compute the arithmetic mean.
179+
180+
- If ``N`` is ``0``, the arithmetic mean is ``NaN``.
181+
- If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN``
182+
(i.e., ``NaN`` values propagate).
183+
"""
88184
if xp.isdtype(x.dtype, "complex floating"):
89185
x_real = xp.real(x)
90186
x_imag = xp.imag(x)

0 commit comments

Comments
 (0)