Skip to content

Commit 1e710bf

Browse files
committed
at() tests for jax.jit
1 parent 7dae32c commit 1e710bf

File tree

2 files changed

+97
-23
lines changed

2 files changed

+97
-23
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import operator
77
from collections.abc import Callable
88
from enum import Enum
9+
from numbers import Number
910
from types import ModuleType
1011
from typing import ClassVar, cast
1112

@@ -188,7 +189,7 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
188189
def _update_common(
189190
self,
190191
at_op: _AtOp,
191-
y: Array,
192+
y: Array | Number,
192193
/,
193194
copy: bool | None,
194195
xp: ModuleType | None,
@@ -253,7 +254,7 @@ def _update_common(
253254

254255
def set(
255256
self,
256-
y: Array,
257+
y: Array | Number,
257258
/,
258259
copy: bool | None = None,
259260
xp: ModuleType | None = None,
@@ -269,8 +270,8 @@ def set(
269270
def _iop(
270271
self,
271272
at_op: _AtOp,
272-
elwise_op: Callable[[Array, Array], Array],
273-
y: Array,
273+
elwise_op: Callable[[Array, Array | Number], Array],
274+
y: Array | Number,
274275
/,
275276
copy: bool | None,
276277
xp: ModuleType | None,
@@ -294,7 +295,7 @@ def _iop(
294295

295296
def add(
296297
self,
297-
y: Array,
298+
y: Array | Number,
298299
/,
299300
copy: bool | None = None,
300301
xp: ModuleType | None = None,
@@ -308,7 +309,7 @@ def add(
308309

309310
def subtract(
310311
self,
311-
y: Array,
312+
y: Array | Number,
312313
/,
313314
copy: bool | None = None,
314315
xp: ModuleType | None = None,
@@ -318,7 +319,7 @@ def subtract(
318319

319320
def multiply(
320321
self,
321-
y: Array,
322+
y: Array | Number,
322323
/,
323324
copy: bool | None = None,
324325
xp: ModuleType | None = None,
@@ -328,7 +329,7 @@ def multiply(
328329

329330
def divide(
330331
self,
331-
y: Array,
332+
y: Array | Number,
332333
/,
333334
copy: bool | None = None,
334335
xp: ModuleType | None = None,
@@ -338,7 +339,7 @@ def divide(
338339

339340
def power(
340341
self,
341-
y: Array,
342+
y: Array | Number,
342343
/,
343344
copy: bool | None = None,
344345
xp: ModuleType | None = None,
@@ -348,7 +349,7 @@ def power(
348349

349350
def min(
350351
self,
351-
y: Array,
352+
y: Array | Number,
352353
/,
353354
copy: bool | None = None,
354355
xp: ModuleType | None = None,
@@ -361,7 +362,7 @@ def min(
361362

362363
def max(
363364
self,
364-
y: Array,
365+
y: Array | Number,
365366
/,
366367
copy: bool | None = None,
367368
xp: ModuleType | None = None,

tests/test_at.py

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import pickle
12
from collections.abc import Callable, Generator
23
from contextlib import contextmanager
4+
from numbers import Number
35
from types import ModuleType
46
from typing import cast
57

@@ -11,7 +13,48 @@
1113
from array_api_extra._lib._at import _AtOp
1214
from array_api_extra._lib._testing import xp_assert_equal
1315
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
14-
from array_api_extra._lib._utils._typing import Array
16+
from array_api_extra._lib._utils._typing import Array, Index
17+
from array_api_extra.testing import lazy_xp_function
18+
19+
20+
def at_op(
21+
x: Array,
22+
idx: Index,
23+
op: _AtOp,
24+
y: Array | Number,
25+
copy: bool | None = None,
26+
xp: ModuleType | None = None,
27+
) -> Array:
28+
"""
29+
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
30+
31+
This is a hack to allow wrapping `at()` with `lazy_xp_function`.
32+
For clarity, at() itself works inside jax.jit without hacks; this is
33+
just a workaround for when one wants to apply jax.jit to `at()` directly,
34+
which is not a common use case.
35+
"""
36+
if isinstance(idx, (slice | tuple)):
37+
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
38+
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
39+
40+
41+
def _at_op(
42+
x: Array,
43+
idx: Index | None,
44+
idx_pickle: bytes | None,
45+
op: _AtOp,
46+
y: Array | Number,
47+
copy: bool | None = None,
48+
xp: ModuleType | None = None,
49+
) -> Array:
50+
"""jitted helper of at_op"""
51+
if idx_pickle:
52+
idx = pickle.loads(idx_pickle)
53+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
54+
return meth(y, copy=copy, xp=xp)
55+
56+
57+
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
1558

1659

1760
@contextmanager
@@ -43,7 +86,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4386
],
4487
)
4588
@pytest.mark.parametrize(
46-
("op", "arg", "expect"),
89+
("op", "y", "expect"),
4790
[
4891
(_AtOp.SET, 40.0, [10.0, 40.0, 40.0]),
4992
(_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]),
@@ -55,21 +98,52 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
5598
(_AtOp.MAX, 25.0, [10.0, 25.0, 30.0]),
5699
],
57100
)
101+
@pytest.mark.parametrize(
102+
("bool_mask", "shaped_y"),
103+
[
104+
(False, False),
105+
(False, True),
106+
pytest.param(
107+
True,
108+
False,
109+
marks=(
110+
pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"),
111+
pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"),
112+
),
113+
),
114+
pytest.param(
115+
True,
116+
True,
117+
marks=(
118+
pytest.mark.skip_xp_backend(
119+
Backend.JAX, reason="bool mask update with shaped rhs"
120+
),
121+
pytest.mark.skip_xp_backend(
122+
Backend.DASK, reason="bool mask update with shaped rhs"
123+
),
124+
),
125+
),
126+
],
127+
)
58128
def test_update_ops(
59129
xp: ModuleType,
60130
kwargs: dict[str, bool | None],
61131
expect_copy: bool | None,
62132
op: _AtOp,
63-
arg: float,
133+
y: float,
64134
expect: list[float],
135+
bool_mask: bool,
136+
shaped_y: bool,
65137
):
66-
array = xp.asarray([10.0, 20.0, 30.0])
138+
x = xp.asarray([10.0, 20.0, 30.0])
139+
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
140+
if shaped_y:
141+
y = xp.asarray([y, y])
67142

68-
with assert_copy(array, expect_copy):
69-
func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit]
70-
y = func(arg, **kwargs)
71-
assert isinstance(y, type(array))
72-
xp_assert_equal(y, xp.asarray(expect))
143+
with assert_copy(x, expect_copy):
144+
z = at_op(x, idx, op, y, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
145+
assert isinstance(z, type(x))
146+
xp_assert_equal(z, xp.asarray(expect))
73147

74148

75149
def test_copy_invalid():
@@ -121,7 +195,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
121195
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
122196
to dtype('int64') with casting rule 'same_kind'
123197
"""
124-
a = np.asarray([2, 4])
125-
func = cast(Callable[..., Array], getattr(at(a)[:], op.value)) # type: ignore[no-any-explicit]
198+
x = np.asarray([2, 4])
126199
with pytest.raises(TypeError, match="Cannot cast ufunc"):
127-
func(1.1, copy=copy)
200+
at_op(x, slice(None), op, 1.1, copy=copy)

0 commit comments

Comments
 (0)