|
18 | 18 | from ._lib._utils._helpers import asarrays
|
19 | 19 | from ._lib._utils._typing import Array, DType
|
20 | 20 |
|
21 |
| -__all__ = ["cov", "isclose", "nan_to_num", "one_hot", "pad"] |
| 21 | +__all__ = [ |
| 22 | + "cov", |
| 23 | + "expand_dims", |
| 24 | + "isclose", |
| 25 | + "nan_to_num", |
| 26 | + "one_hot", |
| 27 | + "pad", |
| 28 | + "sinc", |
| 29 | +] |
22 | 30 |
|
23 | 31 |
|
24 | 32 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
|
@@ -101,6 +109,94 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
|
101 | 109 | return _funcs.cov(m, xp=xp)
|
102 | 110 |
|
103 | 111 |
|
| 112 | +def expand_dims( |
| 113 | + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
| 114 | +) -> Array: |
| 115 | + """ |
| 116 | + Expand the shape of an array. |
| 117 | +
|
| 118 | + Insert (a) new axis/axes that will appear at the position(s) specified by |
| 119 | + `axis` in the expanded array shape. |
| 120 | +
|
| 121 | + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. |
| 122 | + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. |
| 123 | +
|
| 124 | + Parameters |
| 125 | + ---------- |
| 126 | + a : array |
| 127 | + Array to have its shape expanded. |
| 128 | + axis : int or tuple of ints, optional |
| 129 | + Position(s) in the expanded axes where the new axis (or axes) is/are placed. |
| 130 | + If multiple positions are provided, they should be unique (note that a position |
| 131 | + given by a positive index could also be referred to by a negative index - |
| 132 | + that will also result in an error). |
| 133 | + Default: ``(0,)``. |
| 134 | + xp : array_namespace, optional |
| 135 | + The standard-compatible namespace for `a`. Default: infer. |
| 136 | +
|
| 137 | + Returns |
| 138 | + ------- |
| 139 | + array |
| 140 | + `a` with an expanded shape. |
| 141 | +
|
| 142 | + Examples |
| 143 | + -------- |
| 144 | + >>> import array_api_strict as xp |
| 145 | + >>> import array_api_extra as xpx |
| 146 | + >>> x = xp.asarray([1, 2]) |
| 147 | + >>> x.shape |
| 148 | + (2,) |
| 149 | +
|
| 150 | + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: |
| 151 | +
|
| 152 | + >>> y = xpx.expand_dims(x, axis=0, xp=xp) |
| 153 | + >>> y |
| 154 | + Array([[1, 2]], dtype=array_api_strict.int64) |
| 155 | + >>> y.shape |
| 156 | + (1, 2) |
| 157 | +
|
| 158 | + The following is equivalent to ``x[:, xp.newaxis]``: |
| 159 | +
|
| 160 | + >>> y = xpx.expand_dims(x, axis=1, xp=xp) |
| 161 | + >>> y |
| 162 | + Array([[1], |
| 163 | + [2]], dtype=array_api_strict.int64) |
| 164 | + >>> y.shape |
| 165 | + (2, 1) |
| 166 | +
|
| 167 | + ``axis`` may also be a tuple: |
| 168 | +
|
| 169 | + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) |
| 170 | + >>> y |
| 171 | + Array([[[1, 2]]], dtype=array_api_strict.int64) |
| 172 | +
|
| 173 | + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) |
| 174 | + >>> y |
| 175 | + Array([[[1], |
| 176 | + [2]]], dtype=array_api_strict.int64) |
| 177 | + """ |
| 178 | + if xp is None: |
| 179 | + xp = array_namespace(a) |
| 180 | + |
| 181 | + if not isinstance(axis, tuple): |
| 182 | + axis = (axis,) |
| 183 | + ndim = a.ndim + len(axis) |
| 184 | + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): |
| 185 | + err_msg = ( |
| 186 | + f"a provided axis position is out of bounds for array of dimension {a.ndim}" |
| 187 | + ) |
| 188 | + raise IndexError(err_msg) |
| 189 | + axis = tuple(dim % ndim for dim in axis) |
| 190 | + if len(set(axis)) != len(axis): |
| 191 | + err_msg = "Duplicate dimensions specified in `axis`." |
| 192 | + raise ValueError(err_msg) |
| 193 | + |
| 194 | + if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp): |
| 195 | + return xp.expand_dims(a, axis=axis) |
| 196 | + |
| 197 | + return _funcs.expand_dims(a, axis=axis, xp=xp) |
| 198 | + |
| 199 | + |
104 | 200 | def isclose(
|
105 | 201 | a: Array | complex,
|
106 | 202 | b: Array | complex,
|
@@ -406,3 +502,97 @@ def pad(
|
406 | 502 | return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
407 | 503 |
|
408 | 504 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
|
| 505 | + |
| 506 | + |
| 507 | +def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
| 508 | + r""" |
| 509 | + Return the normalized sinc function. |
| 510 | +
|
| 511 | + The sinc function is equal to :math:`\sin(\pi x)/(\pi x)` for any argument |
| 512 | + :math:`x\ne 0`. ``sinc(0)`` takes the limit value 1, making ``sinc`` not |
| 513 | + only everywhere continuous but also infinitely differentiable. |
| 514 | +
|
| 515 | + .. note:: |
| 516 | +
|
| 517 | + Note the normalization factor of ``pi`` used in the definition. |
| 518 | + This is the most commonly used definition in signal processing. |
| 519 | + Use ``sinc(x / xp.pi)`` to obtain the unnormalized sinc function |
| 520 | + :math:`\sin(x)/x` that is more common in mathematics. |
| 521 | +
|
| 522 | + Parameters |
| 523 | + ---------- |
| 524 | + x : array |
| 525 | + Array (possibly multi-dimensional) of values for which to calculate |
| 526 | + ``sinc(x)``. Must have a real floating point dtype. |
| 527 | + xp : array_namespace, optional |
| 528 | + The standard-compatible namespace for `x`. Default: infer. |
| 529 | +
|
| 530 | + Returns |
| 531 | + ------- |
| 532 | + array |
| 533 | + ``sinc(x)`` calculated elementwise, which has the same shape as the input. |
| 534 | +
|
| 535 | + Notes |
| 536 | + ----- |
| 537 | + The name sinc is short for "sine cardinal" or "sinus cardinalis". |
| 538 | +
|
| 539 | + The sinc function is used in various signal processing applications, |
| 540 | + including in anti-aliasing, in the construction of a Lanczos resampling |
| 541 | + filter, and in interpolation. |
| 542 | +
|
| 543 | + For bandlimited interpolation of discrete-time signals, the ideal |
| 544 | + interpolation kernel is proportional to the sinc function. |
| 545 | +
|
| 546 | + References |
| 547 | + ---------- |
| 548 | + #. Weisstein, Eric W. "Sinc Function." From MathWorld--A Wolfram Web |
| 549 | + Resource. https://mathworld.wolfram.com/SincFunction.html |
| 550 | + #. Wikipedia, "Sinc function", |
| 551 | + https://en.wikipedia.org/wiki/Sinc_function |
| 552 | +
|
| 553 | + Examples |
| 554 | + -------- |
| 555 | + >>> import array_api_strict as xp |
| 556 | + >>> import array_api_extra as xpx |
| 557 | + >>> x = xp.linspace(-4, 4, 41) |
| 558 | + >>> xpx.sinc(x, xp=xp) |
| 559 | + Array([-3.89817183e-17, -4.92362781e-02, |
| 560 | + -8.40918587e-02, -8.90384387e-02, |
| 561 | + -5.84680802e-02, 3.89817183e-17, |
| 562 | + 6.68206631e-02, 1.16434881e-01, |
| 563 | + 1.26137788e-01, 8.50444803e-02, |
| 564 | + -3.89817183e-17, -1.03943254e-01, |
| 565 | + -1.89206682e-01, -2.16236208e-01, |
| 566 | + -1.55914881e-01, 3.89817183e-17, |
| 567 | + 2.33872321e-01, 5.04551152e-01, |
| 568 | + 7.56826729e-01, 9.35489284e-01, |
| 569 | + 1.00000000e+00, 9.35489284e-01, |
| 570 | + 7.56826729e-01, 5.04551152e-01, |
| 571 | + 2.33872321e-01, 3.89817183e-17, |
| 572 | + -1.55914881e-01, -2.16236208e-01, |
| 573 | + -1.89206682e-01, -1.03943254e-01, |
| 574 | + -3.89817183e-17, 8.50444803e-02, |
| 575 | + 1.26137788e-01, 1.16434881e-01, |
| 576 | + 6.68206631e-02, 3.89817183e-17, |
| 577 | + -5.84680802e-02, -8.90384387e-02, |
| 578 | + -8.40918587e-02, -4.92362781e-02, |
| 579 | + -3.89817183e-17], dtype=array_api_strict.float64) |
| 580 | + """ |
| 581 | + |
| 582 | + if xp is None: |
| 583 | + xp = array_namespace(x) |
| 584 | + |
| 585 | + if not xp.isdtype(x.dtype, "real floating"): |
| 586 | + err_msg = "`x` must have a real floating data type." |
| 587 | + raise ValueError(err_msg) |
| 588 | + |
| 589 | + if ( |
| 590 | + is_numpy_namespace(xp) |
| 591 | + or is_cupy_namespace(xp) |
| 592 | + or is_jax_namespace(xp) |
| 593 | + or is_torch_namespace(xp) |
| 594 | + or is_dask_namespace(xp) |
| 595 | + ): |
| 596 | + return xp.sinc(x) |
| 597 | + |
| 598 | + return _funcs.sinc(x, xp=xp) |
0 commit comments