Skip to content

Commit 88224c3

Browse files
author
Adrián García Pitarch
committed
Merge branch 'main' into add-cov-delegation
# Conflicts: # src/array_api_extra/__init__.py # src/array_api_extra/_delegation.py
2 parents c56004c + 747f994 commit 88224c3

File tree

3 files changed

+206
-164
lines changed

3 files changed

+206
-164
lines changed

src/array_api_extra/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import cov, isclose, nan_to_num, one_hot, pad
3+
from ._delegation import (
4+
cov,
5+
expand_dims,
6+
isclose,
7+
nan_to_num,
8+
one_hot,
9+
pad,
10+
sinc,
11+
)
412
from ._lib._at import at
513
from ._lib._funcs import (
614
apply_where,
715
atleast_nd,
816
broadcast_shapes,
917
create_diagonal,
1018
default_dtype,
11-
expand_dims,
1219
kron,
1320
nunique,
1421
setdiff1d,
15-
sinc,
1622
)
1723
from ._lib._lazy import lazy_apply
1824

src/array_api_extra/_delegation.py

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["cov", "isclose", "nan_to_num", "one_hot", "pad"]
21+
__all__ = [
22+
"cov",
23+
"expand_dims",
24+
"isclose",
25+
"nan_to_num",
26+
"one_hot",
27+
"pad",
28+
"sinc",
29+
]
2230

2331

2432
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
@@ -101,6 +109,94 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
101109
return _funcs.cov(m, xp=xp)
102110

103111

112+
def expand_dims(
113+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
114+
) -> Array:
115+
"""
116+
Expand the shape of an array.
117+
118+
Insert (a) new axis/axes that will appear at the position(s) specified by
119+
`axis` in the expanded array shape.
120+
121+
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
122+
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
123+
124+
Parameters
125+
----------
126+
a : array
127+
Array to have its shape expanded.
128+
axis : int or tuple of ints, optional
129+
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
130+
If multiple positions are provided, they should be unique (note that a position
131+
given by a positive index could also be referred to by a negative index -
132+
that will also result in an error).
133+
Default: ``(0,)``.
134+
xp : array_namespace, optional
135+
The standard-compatible namespace for `a`. Default: infer.
136+
137+
Returns
138+
-------
139+
array
140+
`a` with an expanded shape.
141+
142+
Examples
143+
--------
144+
>>> import array_api_strict as xp
145+
>>> import array_api_extra as xpx
146+
>>> x = xp.asarray([1, 2])
147+
>>> x.shape
148+
(2,)
149+
150+
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
151+
152+
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
153+
>>> y
154+
Array([[1, 2]], dtype=array_api_strict.int64)
155+
>>> y.shape
156+
(1, 2)
157+
158+
The following is equivalent to ``x[:, xp.newaxis]``:
159+
160+
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
161+
>>> y
162+
Array([[1],
163+
[2]], dtype=array_api_strict.int64)
164+
>>> y.shape
165+
(2, 1)
166+
167+
``axis`` may also be a tuple:
168+
169+
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
170+
>>> y
171+
Array([[[1, 2]]], dtype=array_api_strict.int64)
172+
173+
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
174+
>>> y
175+
Array([[[1],
176+
[2]]], dtype=array_api_strict.int64)
177+
"""
178+
if xp is None:
179+
xp = array_namespace(a)
180+
181+
if not isinstance(axis, tuple):
182+
axis = (axis,)
183+
ndim = a.ndim + len(axis)
184+
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
185+
err_msg = (
186+
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
187+
)
188+
raise IndexError(err_msg)
189+
axis = tuple(dim % ndim for dim in axis)
190+
if len(set(axis)) != len(axis):
191+
err_msg = "Duplicate dimensions specified in `axis`."
192+
raise ValueError(err_msg)
193+
194+
if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp):
195+
return xp.expand_dims(a, axis=axis)
196+
197+
return _funcs.expand_dims(a, axis=axis, xp=xp)
198+
199+
104200
def isclose(
105201
a: Array | complex,
106202
b: Array | complex,
@@ -406,3 +502,97 @@ def pad(
406502
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
407503

408504
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
505+
506+
507+
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
508+
r"""
509+
Return the normalized sinc function.
510+
511+
The sinc function is equal to :math:`\sin(\pi x)/(\pi x)` for any argument
512+
:math:`x\ne 0`. ``sinc(0)`` takes the limit value 1, making ``sinc`` not
513+
only everywhere continuous but also infinitely differentiable.
514+
515+
.. note::
516+
517+
Note the normalization factor of ``pi`` used in the definition.
518+
This is the most commonly used definition in signal processing.
519+
Use ``sinc(x / xp.pi)`` to obtain the unnormalized sinc function
520+
:math:`\sin(x)/x` that is more common in mathematics.
521+
522+
Parameters
523+
----------
524+
x : array
525+
Array (possibly multi-dimensional) of values for which to calculate
526+
``sinc(x)``. Must have a real floating point dtype.
527+
xp : array_namespace, optional
528+
The standard-compatible namespace for `x`. Default: infer.
529+
530+
Returns
531+
-------
532+
array
533+
``sinc(x)`` calculated elementwise, which has the same shape as the input.
534+
535+
Notes
536+
-----
537+
The name sinc is short for "sine cardinal" or "sinus cardinalis".
538+
539+
The sinc function is used in various signal processing applications,
540+
including in anti-aliasing, in the construction of a Lanczos resampling
541+
filter, and in interpolation.
542+
543+
For bandlimited interpolation of discrete-time signals, the ideal
544+
interpolation kernel is proportional to the sinc function.
545+
546+
References
547+
----------
548+
#. Weisstein, Eric W. "Sinc Function." From MathWorld--A Wolfram Web
549+
Resource. https://mathworld.wolfram.com/SincFunction.html
550+
#. Wikipedia, "Sinc function",
551+
https://en.wikipedia.org/wiki/Sinc_function
552+
553+
Examples
554+
--------
555+
>>> import array_api_strict as xp
556+
>>> import array_api_extra as xpx
557+
>>> x = xp.linspace(-4, 4, 41)
558+
>>> xpx.sinc(x, xp=xp)
559+
Array([-3.89817183e-17, -4.92362781e-02,
560+
-8.40918587e-02, -8.90384387e-02,
561+
-5.84680802e-02, 3.89817183e-17,
562+
6.68206631e-02, 1.16434881e-01,
563+
1.26137788e-01, 8.50444803e-02,
564+
-3.89817183e-17, -1.03943254e-01,
565+
-1.89206682e-01, -2.16236208e-01,
566+
-1.55914881e-01, 3.89817183e-17,
567+
2.33872321e-01, 5.04551152e-01,
568+
7.56826729e-01, 9.35489284e-01,
569+
1.00000000e+00, 9.35489284e-01,
570+
7.56826729e-01, 5.04551152e-01,
571+
2.33872321e-01, 3.89817183e-17,
572+
-1.55914881e-01, -2.16236208e-01,
573+
-1.89206682e-01, -1.03943254e-01,
574+
-3.89817183e-17, 8.50444803e-02,
575+
1.26137788e-01, 1.16434881e-01,
576+
6.68206631e-02, 3.89817183e-17,
577+
-5.84680802e-02, -8.90384387e-02,
578+
-8.40918587e-02, -4.92362781e-02,
579+
-3.89817183e-17], dtype=array_api_strict.float64)
580+
"""
581+
582+
if xp is None:
583+
xp = array_namespace(x)
584+
585+
if not xp.isdtype(x.dtype, "real floating"):
586+
err_msg = "`x` must have a real floating data type."
587+
raise ValueError(err_msg)
588+
589+
if (
590+
is_numpy_namespace(xp)
591+
or is_cupy_namespace(xp)
592+
or is_jax_namespace(xp)
593+
or is_torch_namespace(xp)
594+
or is_dask_namespace(xp)
595+
):
596+
return xp.sinc(x)
597+
598+
return _funcs.sinc(x, xp=xp)

0 commit comments

Comments
 (0)