Skip to content

Commit 8d5af47

Browse files
committed
Merge remote-tracking branch 'upstream/main' into partition
2 parents c2827da + 747f994 commit 8d5af47

File tree

4 files changed

+199
-167
lines changed

4 files changed

+199
-167
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from ._delegation import (
44
argpartition,
5+
expand_dims,
56
isclose,
67
nan_to_num,
78
one_hot,
89
pad,
910
partition,
11+
sinc,
1012
)
1113
from ._lib._at import at
1214
from ._lib._funcs import (
@@ -16,11 +18,9 @@
1618
cov,
1719
create_diagonal,
1820
default_dtype,
19-
expand_dims,
2021
kron,
2122
nunique,
2223
setdiff1d,
23-
sinc,
2424
)
2525
from ._lib._lazy import lazy_apply
2626

src/array_api_extra/_delegation.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,95 @@
1818
from ._lib._utils._helpers import asarrays, eager_shape
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
21+
__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad", "sinc"]
22+
23+
24+
def expand_dims(
25+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
26+
) -> Array:
27+
"""
28+
Expand the shape of an array.
29+
30+
Insert (a) new axis/axes that will appear at the position(s) specified by
31+
`axis` in the expanded array shape.
32+
33+
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
34+
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
35+
36+
Parameters
37+
----------
38+
a : array
39+
Array to have its shape expanded.
40+
axis : int or tuple of ints, optional
41+
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
42+
If multiple positions are provided, they should be unique (note that a position
43+
given by a positive index could also be referred to by a negative index -
44+
that will also result in an error).
45+
Default: ``(0,)``.
46+
xp : array_namespace, optional
47+
The standard-compatible namespace for `a`. Default: infer.
48+
49+
Returns
50+
-------
51+
array
52+
`a` with an expanded shape.
53+
54+
Examples
55+
--------
56+
>>> import array_api_strict as xp
57+
>>> import array_api_extra as xpx
58+
>>> x = xp.asarray([1, 2])
59+
>>> x.shape
60+
(2,)
61+
62+
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
63+
64+
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
65+
>>> y
66+
Array([[1, 2]], dtype=array_api_strict.int64)
67+
>>> y.shape
68+
(1, 2)
69+
70+
The following is equivalent to ``x[:, xp.newaxis]``:
71+
72+
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
73+
>>> y
74+
Array([[1],
75+
[2]], dtype=array_api_strict.int64)
76+
>>> y.shape
77+
(2, 1)
78+
79+
``axis`` may also be a tuple:
80+
81+
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
82+
>>> y
83+
Array([[[1, 2]]], dtype=array_api_strict.int64)
84+
85+
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
86+
>>> y
87+
Array([[[1],
88+
[2]]], dtype=array_api_strict.int64)
89+
"""
90+
if xp is None:
91+
xp = array_namespace(a)
92+
93+
if not isinstance(axis, tuple):
94+
axis = (axis,)
95+
ndim = a.ndim + len(axis)
96+
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
97+
err_msg = (
98+
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
99+
)
100+
raise IndexError(err_msg)
101+
axis = tuple(dim % ndim for dim in axis)
102+
if len(set(axis)) != len(axis):
103+
err_msg = "Duplicate dimensions specified in `axis`."
104+
raise ValueError(err_msg)
105+
106+
if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp):
107+
return xp.expand_dims(a, axis=axis)
108+
109+
return _funcs.expand_dims(a, axis=axis, xp=xp)
22110

23111

24112
def isclose(
@@ -328,6 +416,100 @@ def pad(
328416
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
329417

330418

419+
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
420+
r"""
421+
Return the normalized sinc function.
422+
423+
The sinc function is equal to :math:`\sin(\pi x)/(\pi x)` for any argument
424+
:math:`x\ne 0`. ``sinc(0)`` takes the limit value 1, making ``sinc`` not
425+
only everywhere continuous but also infinitely differentiable.
426+
427+
.. note::
428+
429+
Note the normalization factor of ``pi`` used in the definition.
430+
This is the most commonly used definition in signal processing.
431+
Use ``sinc(x / xp.pi)`` to obtain the unnormalized sinc function
432+
:math:`\sin(x)/x` that is more common in mathematics.
433+
434+
Parameters
435+
----------
436+
x : array
437+
Array (possibly multi-dimensional) of values for which to calculate
438+
``sinc(x)``. Must have a real floating point dtype.
439+
xp : array_namespace, optional
440+
The standard-compatible namespace for `x`. Default: infer.
441+
442+
Returns
443+
-------
444+
array
445+
``sinc(x)`` calculated elementwise, which has the same shape as the input.
446+
447+
Notes
448+
-----
449+
The name sinc is short for "sine cardinal" or "sinus cardinalis".
450+
451+
The sinc function is used in various signal processing applications,
452+
including in anti-aliasing, in the construction of a Lanczos resampling
453+
filter, and in interpolation.
454+
455+
For bandlimited interpolation of discrete-time signals, the ideal
456+
interpolation kernel is proportional to the sinc function.
457+
458+
References
459+
----------
460+
#. Weisstein, Eric W. "Sinc Function." From MathWorld--A Wolfram Web
461+
Resource. https://mathworld.wolfram.com/SincFunction.html
462+
#. Wikipedia, "Sinc function",
463+
https://en.wikipedia.org/wiki/Sinc_function
464+
465+
Examples
466+
--------
467+
>>> import array_api_strict as xp
468+
>>> import array_api_extra as xpx
469+
>>> x = xp.linspace(-4, 4, 41)
470+
>>> xpx.sinc(x, xp=xp)
471+
Array([-3.89817183e-17, -4.92362781e-02,
472+
-8.40918587e-02, -8.90384387e-02,
473+
-5.84680802e-02, 3.89817183e-17,
474+
6.68206631e-02, 1.16434881e-01,
475+
1.26137788e-01, 8.50444803e-02,
476+
-3.89817183e-17, -1.03943254e-01,
477+
-1.89206682e-01, -2.16236208e-01,
478+
-1.55914881e-01, 3.89817183e-17,
479+
2.33872321e-01, 5.04551152e-01,
480+
7.56826729e-01, 9.35489284e-01,
481+
1.00000000e+00, 9.35489284e-01,
482+
7.56826729e-01, 5.04551152e-01,
483+
2.33872321e-01, 3.89817183e-17,
484+
-1.55914881e-01, -2.16236208e-01,
485+
-1.89206682e-01, -1.03943254e-01,
486+
-3.89817183e-17, 8.50444803e-02,
487+
1.26137788e-01, 1.16434881e-01,
488+
6.68206631e-02, 3.89817183e-17,
489+
-5.84680802e-02, -8.90384387e-02,
490+
-8.40918587e-02, -4.92362781e-02,
491+
-3.89817183e-17], dtype=array_api_strict.float64)
492+
"""
493+
494+
if xp is None:
495+
xp = array_namespace(x)
496+
497+
if not xp.isdtype(x.dtype, "real floating"):
498+
err_msg = "`x` must have a real floating data type."
499+
raise ValueError(err_msg)
500+
501+
if (
502+
is_numpy_namespace(xp)
503+
or is_cupy_namespace(xp)
504+
or is_jax_namespace(xp)
505+
or is_torch_namespace(xp)
506+
or is_dask_namespace(xp)
507+
):
508+
return xp.sinc(x)
509+
510+
return _funcs.sinc(x, xp=xp)
511+
512+
331513
def partition(
332514
a: Array,
333515
kth: int,

0 commit comments

Comments
 (0)