Skip to content

Commit 7811bfc

Browse files
committed
at() tests for jax.jit
1 parent 910ffc0 commit 7811bfc

File tree

2 files changed

+93
-23
lines changed

2 files changed

+93
-23
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import operator
77
from collections.abc import Callable
88
from enum import Enum
9+
from numbers import Number
910
from types import ModuleType
1011
from typing import ClassVar, cast
1112

@@ -188,7 +189,7 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
188189
def _update_common(
189190
self,
190191
at_op: _AtOp,
191-
y: Array,
192+
y: Array | Number,
192193
/,
193194
copy: bool | None,
194195
xp: ModuleType | None,
@@ -253,7 +254,7 @@ def _update_common(
253254

254255
def set(
255256
self,
256-
y: Array,
257+
y: Array | Number,
257258
/,
258259
copy: bool | None = None,
259260
xp: ModuleType | None = None,
@@ -269,8 +270,8 @@ def set(
269270
def _iop(
270271
self,
271272
at_op: _AtOp,
272-
elwise_op: Callable[[Array, Array], Array],
273-
y: Array,
273+
elwise_op: Callable[[Array, Array | Number], Array],
274+
y: Array | Number,
274275
/,
275276
copy: bool | None,
276277
xp: ModuleType | None,
@@ -294,7 +295,7 @@ def _iop(
294295

295296
def add(
296297
self,
297-
y: Array,
298+
y: Array | Number,
298299
/,
299300
copy: bool | None = None,
300301
xp: ModuleType | None = None,
@@ -308,7 +309,7 @@ def add(
308309

309310
def subtract(
310311
self,
311-
y: Array,
312+
y: Array | Number,
312313
/,
313314
copy: bool | None = None,
314315
xp: ModuleType | None = None,
@@ -318,7 +319,7 @@ def subtract(
318319

319320
def multiply(
320321
self,
321-
y: Array,
322+
y: Array | Number,
322323
/,
323324
copy: bool | None = None,
324325
xp: ModuleType | None = None,
@@ -328,7 +329,7 @@ def multiply(
328329

329330
def divide(
330331
self,
331-
y: Array,
332+
y: Array | Number,
332333
/,
333334
copy: bool | None = None,
334335
xp: ModuleType | None = None,
@@ -338,7 +339,7 @@ def divide(
338339

339340
def power(
340341
self,
341-
y: Array,
342+
y: Array | Number,
342343
/,
343344
copy: bool | None = None,
344345
xp: ModuleType | None = None,
@@ -348,7 +349,7 @@ def power(
348349

349350
def min(
350351
self,
351-
y: Array,
352+
y: Array | Number,
352353
/,
353354
copy: bool | None = None,
354355
xp: ModuleType | None = None,
@@ -361,7 +362,7 @@ def min(
361362

362363
def max(
363364
self,
364-
y: Array,
365+
y: Array | Number,
365366
/,
366367
copy: bool | None = None,
367368
xp: ModuleType | None = None,

tests/test_at.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import pickle
12
from collections.abc import Callable, Generator
23
from contextlib import contextmanager
4+
from numbers import Number
35
from types import ModuleType
46
from typing import cast
57

@@ -11,7 +13,45 @@
1113
from array_api_extra._lib._at import _AtOp
1214
from array_api_extra._lib._testing import xp_assert_equal
1315
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
14-
from array_api_extra._lib._utils._typing import Array
16+
from array_api_extra._lib._utils._typing import Array, Index
17+
from array_api_extra.testing import lazy_xp_function
18+
19+
20+
def at_op(
21+
x: Array,
22+
idx: Index,
23+
op: _AtOp,
24+
y: Array | Number,
25+
copy: bool | None = None,
26+
xp: ModuleType | None = None,
27+
) -> Array:
28+
"""
29+
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
30+
31+
This is a hack to allow wrapping `at()` with `lazy_xp_function`.
32+
"""
33+
if isinstance(idx, (slice | tuple)):
34+
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
35+
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
36+
37+
38+
def _at_op(
39+
x: Array,
40+
idx: Index | None,
41+
idx_pickle: bytes | None,
42+
op: _AtOp,
43+
y: Array | Number,
44+
copy: bool | None = None,
45+
xp: ModuleType | None = None,
46+
) -> Array:
47+
"""jitted helper of at_op"""
48+
if idx_pickle:
49+
idx = pickle.loads(idx_pickle)
50+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
51+
return meth(y, copy=copy, xp=xp)
52+
53+
54+
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
1555

1656

1757
@contextmanager
@@ -43,7 +83,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4383
],
4484
)
4585
@pytest.mark.parametrize(
46-
("op", "arg", "expect"),
86+
("op", "y", "expect"),
4787
[
4888
(_AtOp.SET, 40.0, [10.0, 40.0, 40.0]),
4989
(_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]),
@@ -55,21 +95,51 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
5595
(_AtOp.MAX, 25.0, [10.0, 25.0, 30.0]),
5696
],
5797
)
98+
@pytest.mark.parametrize(
99+
("bool_mask", "shaped_y"),
100+
[
101+
(False, False),
102+
pytest.param(
103+
True,
104+
False,
105+
marks=(
106+
pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"),
107+
pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"),
108+
),
109+
),
110+
pytest.param(
111+
True,
112+
True,
113+
marks=(
114+
pytest.mark.skip_xp_backend(
115+
Backend.JAX, reason="bool mask update with shaped rhs"
116+
),
117+
pytest.mark.skip_xp_backend(
118+
Backend.DASK, reason="bool mask update with shaped rhs"
119+
),
120+
),
121+
),
122+
],
123+
)
58124
def test_update_ops(
59125
xp: ModuleType,
60126
kwargs: dict[str, bool | None],
61127
expect_copy: bool | None,
62128
op: _AtOp,
63-
arg: float,
129+
y: float,
64130
expect: list[float],
131+
bool_mask: bool,
132+
shaped_y: bool,
65133
):
66-
array = xp.asarray([10.0, 20.0, 30.0])
134+
x = xp.asarray([10.0, 20.0, 30.0])
135+
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
136+
if shaped_y:
137+
y = xp.asarray([y, y])
67138

68-
with assert_copy(array, expect_copy):
69-
func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit]
70-
y = func(arg, **kwargs)
71-
assert isinstance(y, type(array))
72-
xp_assert_equal(y, xp.asarray(expect))
139+
with assert_copy(x, expect_copy):
140+
z = at_op(x, idx, op, y, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
141+
assert isinstance(z, type(x))
142+
xp_assert_equal(z, xp.asarray(expect))
73143

74144

75145
def test_copy_invalid():
@@ -121,7 +191,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
121191
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
122192
to dtype('int64') with casting rule 'same_kind'
123193
"""
124-
a = np.asarray([2, 4])
125-
func = cast(Callable[..., Array], getattr(at(a)[:], op.value)) # type: ignore[no-any-explicit]
194+
x = np.asarray([2, 4])
126195
with pytest.raises(TypeError, match="Cannot cast ufunc"):
127-
func(1.1, copy=copy)
196+
at_op(x, slice(None), op, 1.1, copy=copy)

0 commit comments

Comments
 (0)