Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ enable_error_code = ["ignore-without-code", "truthy-bool"]
# https://github.com/data-apis/array-api-typing
disallow_any_expr = false
# false positives with input validation
disable_error_code = ["redundant-expr", "unreachable"]
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]

[[tool.mypy.overrides]]
# slow/unavailable on Windows; do not add to the lint env
Expand Down
59 changes: 34 additions & 25 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_jax_array,
is_writeable_array,
)
from ._utils._typing import Array, Index
from ._utils._typing import Array, SetIndex


class _AtOp(Enum):
Expand Down Expand Up @@ -43,7 +43,13 @@ def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[
return self.value


_undef = object()
class Undef(Enum):
"""Sentinel for undefined values."""

UNDEF = 0


_undef = Undef.UNDEF


class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
Expand Down Expand Up @@ -188,16 +194,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
"""

_x: Array
_idx: Index
_idx: SetIndex | Undef
__slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")

def __init__(
self, x: Array, idx: Index = _undef, /
self, x: Array, idx: SetIndex | Undef = _undef, /
) -> None: # numpydoc ignore=GL08
self._x = x
self._idx = idx

def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
def __getitem__(self, idx: SetIndex, /) -> at: # numpydoc ignore=PR01,RT01
"""
Allow for the alternate syntax ``at(x)[start:stop:step]``.

Expand All @@ -212,9 +218,9 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
def _op(
self,
at_op: _AtOp,
in_place_op: Callable[[Array, Array | object], Array] | None,
in_place_op: Callable[[Array, Array | complex], Array] | None,
out_of_place_op: Callable[[Array, Array], Array] | None,
y: Array | object,
y: Array | complex,
/,
copy: bool | None,
xp: ModuleType | None,
Expand All @@ -226,7 +232,7 @@ def _op(
----------
at_op : _AtOp
Method of JAX's Array.at[].
in_place_op : Callable[[Array, Array | object], Array] | None
in_place_op : Callable[[Array, Array | complex], Array] | None
In-place operation to apply on mutable backends::

x[idx] = in_place_op(x[idx], y)
Expand All @@ -245,7 +251,7 @@ def _op(

x = xp.where(idx, y, x)

y : array or object
y : array or complex
Right-hand side of the operation.
copy : bool or None
Whether to copy the input array. See the class docstring for details.
Expand All @@ -260,7 +266,7 @@ def _op(
x, idx = self._x, self._idx
xp = array_namespace(x, y) if xp is None else xp

if idx is _undef:
if isinstance(idx, Undef):
msg = (
"Index has not been set.\n"
"Usage: either\n"
Expand Down Expand Up @@ -306,7 +312,10 @@ def _op(
if copy or (copy is None and not writeable):
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
func = cast(
Callable[[Array | complex], Array],
getattr(x.at[idx], at_op.value), # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType]
)
out = func(y)
# Undo int->float promotion on JAX after _AtOp.DIVIDE
return xp.astype(out, x.dtype, copy=False)
Expand All @@ -315,10 +324,10 @@ def _op(
# with a copy followed by an update

x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
writeable = None
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
assert not writeable
writeable = None

if writeable is None:
writeable = is_writeable_array(x)
Expand All @@ -328,14 +337,14 @@ def _op(
raise ValueError(msg)

if in_place_op: # add(), subtract(), ...
x[self._idx] = in_place_op(x[self._idx], y)
x[idx] = in_place_op(x[idx], y)
else: # set()
x[self._idx] = y
x[idx] = y
return x

def set(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -345,7 +354,7 @@ def set(

def add(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -359,7 +368,7 @@ def add(

def subtract(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -371,7 +380,7 @@ def subtract(

def multiply(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -383,7 +392,7 @@ def multiply(

def divide(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -395,7 +404,7 @@ def divide(

def power(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -405,7 +414,7 @@ def power(

def min(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -417,7 +426,7 @@ def min(

def max(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand Down
59 changes: 24 additions & 35 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import TYPE_CHECKING, cast
from typing import cast

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays, ndindex
from ._utils._helpers import asarrays, eager_shape, ndindex
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -211,11 +211,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = xp.astype(m, dtype)

avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

m_shape = eager_shape(m)
fact = m_shape[1] - 1

if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
fact = 0.0
fact = 0

m -= avg[:, None]
m_transpose = m.T
Expand Down Expand Up @@ -274,8 +276,10 @@ def create_diagonal(
if x.ndim == 0:
err_msg = "`x` must be at least 1-dimensional."
raise ValueError(err_msg)
batch_dims = x.shape[:-1]
n = x.shape[-1] + abs(offset)

x_shape = eager_shape(x)
batch_dims = x_shape[:-1]
n = x_shape[-1] + abs(offset)
diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x))

target_slice = slice(
Expand Down Expand Up @@ -385,10 +389,6 @@ def isclose(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -505,24 +505,17 @@ def kron(
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(a, singletons + a.shape)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
if nd_a == 0 or nd_b == 0:
return xp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
a_shape = eager_shape(a)
b_shape = eager_shape(b)

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
Expand Down Expand Up @@ -587,16 +580,14 @@ def pad(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
# make pad_width a list of length-2 tuples of ints
x_ndim = cast(int, x.ndim)

if isinstance(pad_width, int):
pad_width_seq = [(pad_width, pad_width)] * x_ndim
pad_width_seq = [(pad_width, pad_width)] * x.ndim
elif (
isinstance(pad_width, tuple)
and len(pad_width) == 2
and all(isinstance(i, int) for i in pad_width)
):
pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim
pad_width_seq = [cast(tuple[int, int], pad_width)] * x.ndim
else:
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))

Expand All @@ -608,7 +599,8 @@ def pad(
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
raise ValueError(msg)

sh = x.shape[ax]
sh = eager_shape(x)[ax]

if w_tpl[0] == 0 and w_tpl[1] == 0:
sl = slice(None, None, None)
else:
Expand Down Expand Up @@ -674,20 +666,17 @@ def setdiff1d(
"""
if xp is None:
xp = array_namespace(x1, x2)
x1, x2 = asarrays(x1, x2, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10103
x1_, x2_ = asarrays(x1, x2, xp=xp)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
x2 = xp.reshape(x2, (-1,))
x1_ = xp.reshape(x1_, (-1,))
x2_ = xp.reshape(x2_, (-1,))
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)

# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(x1)
x1_ = xp.unique_values(x1_)
x2_ = xp.unique_values(x2_)

return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1_[_helpers.in1d(x1_, x2_, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down
Loading