We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f1652b6 commit ec90884Copy full SHA for ec90884
src/array_api_extra/_funcs.py
@@ -669,7 +669,7 @@ class at: # pylint: disable=invalid-name
669
670
_x: Array
671
_idx: Index
672
- __slots__: ClassVar[tuple[str, str]] = ("_idx", "_x")
+ __slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")
673
674
def __init__(self, x: Array, idx: Index = _undef, /) -> None:
675
self._x = x
@@ -728,7 +728,7 @@ def _update_common(
728
729
if copy:
730
if is_jax_array(x):
731
- # Use JAX's at[] or other library that with the same duck-type API
+ # Use JAX's at[]
732
func = getattr(x.at[idx], at_op)
733
return func(y), None
734
# Emulate at[] behaviour for non-JAX arrays
0 commit comments