Skip to content

Commit f15a631

Browse files
committed
WIP tests
1 parent 197a523 commit f15a631

File tree

2 files changed

+119
-14
lines changed

2 files changed

+119
-14
lines changed

tests/conftest.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,7 @@ def xp(
148148
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
149149

150150
if library.like(Backend.JAX):
151-
import jax
152-
153-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
154-
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
155-
156-
if library == Backend.JAX_GPU:
157-
try:
158-
device = jax.devices("cuda")[0]
159-
except RuntimeError:
160-
pytest.skip("no CUDA device available")
161-
else:
162-
device = jax.devices("cpu")[0]
163-
jax.config.update("jax_default_device", device)
151+
_setup_jax(library)
164152

165153
elif library == Backend.TORCH_GPU:
166154
import torch.cuda
@@ -175,6 +163,22 @@ def xp(
175163
yield xp
176164

177165

166+
def _setup_jax(library: Backend) -> None:
167+
import jax
168+
169+
# suppress unused-ignore to run mypy in -e lint as well as -e dev
170+
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
171+
172+
if library == Backend.JAX_GPU:
173+
try:
174+
device = jax.devices("cuda")[0]
175+
except RuntimeError:
176+
pytest.skip("no CUDA device available")
177+
else:
178+
device = jax.devices("cpu")[0]
179+
jax.config.update("jax_default_device", device)
180+
181+
178182
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
179183
def da(
180184
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -186,6 +190,17 @@ def da(
186190
return xp
187191

188192

193+
@pytest.fixture(params=[Backend.JAX, Backend.JAX_GPU])
194+
def jnp(
195+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
196+
) -> ModuleType: # numpydoc ignore=PR01,RT01
197+
"""Variant of the `xp` fixture that only yields jax.numpy."""
198+
xp = pytest.importorskip("jax.numpy")
199+
_setup_jax(request.param)
200+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
201+
return xp
202+
203+
189204
@pytest.fixture
190205
def device(
191206
library: Backend, xp: ModuleType

tests/test_helpers.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from types import ModuleType
2-
from typing import cast
2+
from typing import Generic, TypeVar, cast
33

44
import numpy as np
55
import pytest
@@ -13,6 +13,7 @@
1313
capabilities,
1414
eager_shape,
1515
in1d,
16+
jax_autojit,
1617
meta_namespace,
1718
ndindex,
1819
)
@@ -23,6 +24,8 @@
2324

2425
# mypy: disable-error-code=no-untyped-usage
2526

27+
T = TypeVar("T")
28+
2629
# FIXME calls xp.unique_values without size
2730
lazy_xp_function(in1d, jax_jit=False)
2831

@@ -204,3 +207,90 @@ def test_capabilities(xp: ModuleType):
204207
if xp.__array_api_version__ >= "2024.12":
205208
expect.add("max dimensions")
206209
assert capabilities(xp).keys() == expect
210+
211+
212+
class Wrapper(Generic[T]):
213+
"""Trivial opaque wrapper. Must be pickleable."""
214+
215+
x: T
216+
# __slots__ make this object serializable with __reduce_ex__(5),
217+
# but not with __reduce__
218+
__slots__: tuple[str, ...] = ("x",)
219+
220+
def __init__(self, x: T):
221+
self.x = x
222+
223+
224+
class TestJAXAutoJIT:
225+
def test_basic(self, jnp: ModuleType):
226+
@jax_autojit
227+
def f(x: Array, k: object = False) -> Array:
228+
return x + 1 if k else x - 1
229+
230+
# Basic recognition of static_argnames
231+
xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1]))
232+
xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1]))
233+
xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3]))
234+
xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3]))
235+
236+
# static argument is not an ArrayLike
237+
xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3]))
238+
239+
# static argument is not hashable, but serializable
240+
xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3]))
241+
242+
def test_wrapper(self, jnp: ModuleType):
243+
@jax_autojit
244+
def f(w: Wrapper[Array]) -> Wrapper[Array]:
245+
return Wrapper(w.x + 1)
246+
247+
inp = Wrapper(jnp.asarray([1, 2]))
248+
out = f(inp).x
249+
xp_assert_equal(out, jnp.asarray([2, 3]))
250+
251+
def test_static_hashable(self, jnp: ModuleType):
252+
"""Static argument/return value is hashable, but not serializable"""
253+
254+
class C:
255+
def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
256+
raise Exception()
257+
258+
@jax_autojit
259+
def f(x: object) -> object:
260+
return x
261+
262+
inp = C()
263+
out = f(inp)
264+
assert out is inp
265+
266+
# Serializable opaque input contains non-serializable object plus array
267+
inp = Wrapper((C(), jnp.asarray([1, 2])))
268+
out = f(inp)
269+
assert isinstance(out, Wrapper)
270+
assert out.x[0] is inp.x[0]
271+
assert out.x[1] is not inp.x[1]
272+
xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType]
273+
274+
def test_arraylikes_are_static(self):
275+
pytest.importorskip("jax")
276+
277+
@jax_autojit
278+
def f(x: list[int]) -> list[int]:
279+
assert isinstance(x, list)
280+
assert x == [1, 2]
281+
return [3, 4]
282+
283+
out = f([1, 2])
284+
assert isinstance(out, list)
285+
assert out == [3, 4]
286+
287+
def test_repeated_objects(self, jnp: ModuleType):
288+
@jax_autojit
289+
def f(x: Array, y: Array) -> tuple[Array, Array]:
290+
z = x + y
291+
return z, z
292+
293+
inp = jnp.asarray([1, 2])
294+
out = f(inp, inp)
295+
assert out[0] is out[1]
296+
xp_assert_equal(out[0], jnp.asarray([2, 4]))

0 commit comments

Comments
 (0)