|
18 | 18 | from ._lib._utils._helpers import asarrays, eager_shape
|
19 | 19 | from ._lib._utils._typing import Array, DType
|
20 | 20 |
|
21 |
| -__all__ = ["isclose", "nan_to_num", "one_hot", "pad"] |
| 21 | +__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad", "sinc"] |
| 22 | + |
| 23 | + |
| 24 | +def expand_dims( |
| 25 | + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
| 26 | +) -> Array: |
| 27 | + """ |
| 28 | + Expand the shape of an array. |
| 29 | +
|
| 30 | + Insert (a) new axis/axes that will appear at the position(s) specified by |
| 31 | + `axis` in the expanded array shape. |
| 32 | +
|
| 33 | + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. |
| 34 | + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. |
| 35 | +
|
| 36 | + Parameters |
| 37 | + ---------- |
| 38 | + a : array |
| 39 | + Array to have its shape expanded. |
| 40 | + axis : int or tuple of ints, optional |
| 41 | + Position(s) in the expanded axes where the new axis (or axes) is/are placed. |
| 42 | + If multiple positions are provided, they should be unique (note that a position |
| 43 | + given by a positive index could also be referred to by a negative index - |
| 44 | + that will also result in an error). |
| 45 | + Default: ``(0,)``. |
| 46 | + xp : array_namespace, optional |
| 47 | + The standard-compatible namespace for `a`. Default: infer. |
| 48 | +
|
| 49 | + Returns |
| 50 | + ------- |
| 51 | + array |
| 52 | + `a` with an expanded shape. |
| 53 | +
|
| 54 | + Examples |
| 55 | + -------- |
| 56 | + >>> import array_api_strict as xp |
| 57 | + >>> import array_api_extra as xpx |
| 58 | + >>> x = xp.asarray([1, 2]) |
| 59 | + >>> x.shape |
| 60 | + (2,) |
| 61 | +
|
| 62 | + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: |
| 63 | +
|
| 64 | + >>> y = xpx.expand_dims(x, axis=0, xp=xp) |
| 65 | + >>> y |
| 66 | + Array([[1, 2]], dtype=array_api_strict.int64) |
| 67 | + >>> y.shape |
| 68 | + (1, 2) |
| 69 | +
|
| 70 | + The following is equivalent to ``x[:, xp.newaxis]``: |
| 71 | +
|
| 72 | + >>> y = xpx.expand_dims(x, axis=1, xp=xp) |
| 73 | + >>> y |
| 74 | + Array([[1], |
| 75 | + [2]], dtype=array_api_strict.int64) |
| 76 | + >>> y.shape |
| 77 | + (2, 1) |
| 78 | +
|
| 79 | + ``axis`` may also be a tuple: |
| 80 | +
|
| 81 | + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) |
| 82 | + >>> y |
| 83 | + Array([[[1, 2]]], dtype=array_api_strict.int64) |
| 84 | +
|
| 85 | + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) |
| 86 | + >>> y |
| 87 | + Array([[[1], |
| 88 | + [2]]], dtype=array_api_strict.int64) |
| 89 | + """ |
| 90 | + if xp is None: |
| 91 | + xp = array_namespace(a) |
| 92 | + |
| 93 | + if not isinstance(axis, tuple): |
| 94 | + axis = (axis,) |
| 95 | + ndim = a.ndim + len(axis) |
| 96 | + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): |
| 97 | + err_msg = ( |
| 98 | + f"a provided axis position is out of bounds for array of dimension {a.ndim}" |
| 99 | + ) |
| 100 | + raise IndexError(err_msg) |
| 101 | + axis = tuple(dim % ndim for dim in axis) |
| 102 | + if len(set(axis)) != len(axis): |
| 103 | + err_msg = "Duplicate dimensions specified in `axis`." |
| 104 | + raise ValueError(err_msg) |
| 105 | + |
| 106 | + if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp): |
| 107 | + return xp.expand_dims(a, axis=axis) |
| 108 | + |
| 109 | + return _funcs.expand_dims(a, axis=axis, xp=xp) |
22 | 110 |
|
23 | 111 |
|
24 | 112 | def isclose(
|
@@ -328,6 +416,100 @@ def pad(
|
328 | 416 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
|
329 | 417 |
|
330 | 418 |
|
| 419 | +def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
| 420 | + r""" |
| 421 | + Return the normalized sinc function. |
| 422 | +
|
| 423 | + The sinc function is equal to :math:`\sin(\pi x)/(\pi x)` for any argument |
| 424 | + :math:`x\ne 0`. ``sinc(0)`` takes the limit value 1, making ``sinc`` not |
| 425 | + only everywhere continuous but also infinitely differentiable. |
| 426 | +
|
| 427 | + .. note:: |
| 428 | +
|
| 429 | + Note the normalization factor of ``pi`` used in the definition. |
| 430 | + This is the most commonly used definition in signal processing. |
| 431 | + Use ``sinc(x / xp.pi)`` to obtain the unnormalized sinc function |
| 432 | + :math:`\sin(x)/x` that is more common in mathematics. |
| 433 | +
|
| 434 | + Parameters |
| 435 | + ---------- |
| 436 | + x : array |
| 437 | + Array (possibly multi-dimensional) of values for which to calculate |
| 438 | + ``sinc(x)``. Must have a real floating point dtype. |
| 439 | + xp : array_namespace, optional |
| 440 | + The standard-compatible namespace for `x`. Default: infer. |
| 441 | +
|
| 442 | + Returns |
| 443 | + ------- |
| 444 | + array |
| 445 | + ``sinc(x)`` calculated elementwise, which has the same shape as the input. |
| 446 | +
|
| 447 | + Notes |
| 448 | + ----- |
| 449 | + The name sinc is short for "sine cardinal" or "sinus cardinalis". |
| 450 | +
|
| 451 | + The sinc function is used in various signal processing applications, |
| 452 | + including in anti-aliasing, in the construction of a Lanczos resampling |
| 453 | + filter, and in interpolation. |
| 454 | +
|
| 455 | + For bandlimited interpolation of discrete-time signals, the ideal |
| 456 | + interpolation kernel is proportional to the sinc function. |
| 457 | +
|
| 458 | + References |
| 459 | + ---------- |
| 460 | + #. Weisstein, Eric W. "Sinc Function." From MathWorld--A Wolfram Web |
| 461 | + Resource. https://mathworld.wolfram.com/SincFunction.html |
| 462 | + #. Wikipedia, "Sinc function", |
| 463 | + https://en.wikipedia.org/wiki/Sinc_function |
| 464 | +
|
| 465 | + Examples |
| 466 | + -------- |
| 467 | + >>> import array_api_strict as xp |
| 468 | + >>> import array_api_extra as xpx |
| 469 | + >>> x = xp.linspace(-4, 4, 41) |
| 470 | + >>> xpx.sinc(x, xp=xp) |
| 471 | + Array([-3.89817183e-17, -4.92362781e-02, |
| 472 | + -8.40918587e-02, -8.90384387e-02, |
| 473 | + -5.84680802e-02, 3.89817183e-17, |
| 474 | + 6.68206631e-02, 1.16434881e-01, |
| 475 | + 1.26137788e-01, 8.50444803e-02, |
| 476 | + -3.89817183e-17, -1.03943254e-01, |
| 477 | + -1.89206682e-01, -2.16236208e-01, |
| 478 | + -1.55914881e-01, 3.89817183e-17, |
| 479 | + 2.33872321e-01, 5.04551152e-01, |
| 480 | + 7.56826729e-01, 9.35489284e-01, |
| 481 | + 1.00000000e+00, 9.35489284e-01, |
| 482 | + 7.56826729e-01, 5.04551152e-01, |
| 483 | + 2.33872321e-01, 3.89817183e-17, |
| 484 | + -1.55914881e-01, -2.16236208e-01, |
| 485 | + -1.89206682e-01, -1.03943254e-01, |
| 486 | + -3.89817183e-17, 8.50444803e-02, |
| 487 | + 1.26137788e-01, 1.16434881e-01, |
| 488 | + 6.68206631e-02, 3.89817183e-17, |
| 489 | + -5.84680802e-02, -8.90384387e-02, |
| 490 | + -8.40918587e-02, -4.92362781e-02, |
| 491 | + -3.89817183e-17], dtype=array_api_strict.float64) |
| 492 | + """ |
| 493 | + |
| 494 | + if xp is None: |
| 495 | + xp = array_namespace(x) |
| 496 | + |
| 497 | + if not xp.isdtype(x.dtype, "real floating"): |
| 498 | + err_msg = "`x` must have a real floating data type." |
| 499 | + raise ValueError(err_msg) |
| 500 | + |
| 501 | + if ( |
| 502 | + is_numpy_namespace(xp) |
| 503 | + or is_cupy_namespace(xp) |
| 504 | + or is_jax_namespace(xp) |
| 505 | + or is_torch_namespace(xp) |
| 506 | + or is_dask_namespace(xp) |
| 507 | + ): |
| 508 | + return xp.sinc(x) |
| 509 | + |
| 510 | + return _funcs.sinc(x, xp=xp) |
| 511 | + |
| 512 | + |
331 | 513 | def partition(
|
332 | 514 | a: Array,
|
333 | 515 | kth: int,
|
|
0 commit comments