|
1 | 1 | from types import ModuleType |
2 | | -from typing import cast |
| 2 | +from typing import Generic, TypeVar, cast |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
|
13 | 13 | capabilities, |
14 | 14 | eager_shape, |
15 | 15 | in1d, |
| 16 | + jax_autojit, |
16 | 17 | meta_namespace, |
17 | 18 | ndindex, |
18 | 19 | ) |
|
23 | 24 |
|
24 | 25 | # mypy: disable-error-code=no-untyped-usage |
25 | 26 |
|
| 27 | +T = TypeVar("T") |
| 28 | + |
26 | 29 | # FIXME calls xp.unique_values without size |
27 | 30 | lazy_xp_function(in1d, jax_jit=False) |
28 | 31 |
|
@@ -204,3 +207,90 @@ def test_capabilities(xp: ModuleType): |
204 | 207 | if xp.__array_api_version__ >= "2024.12": |
205 | 208 | expect.add("max dimensions") |
206 | 209 | assert capabilities(xp).keys() == expect |
| 210 | + |
| 211 | + |
| 212 | +class Wrapper(Generic[T]): |
| 213 | + """Trivial opaque wrapper. Must be pickleable.""" |
| 214 | + |
| 215 | + x: T |
| 216 | + # __slots__ make this object serializable with __reduce_ex__(5), |
| 217 | + # but not with __reduce__ |
| 218 | + __slots__: tuple[str, ...] = ("x",) |
| 219 | + |
| 220 | + def __init__(self, x: T): |
| 221 | + self.x = x |
| 222 | + |
| 223 | + |
| 224 | +class TestJAXAutoJIT: |
| 225 | + def test_basic(self, jnp: ModuleType): |
| 226 | + @jax_autojit |
| 227 | + def f(x: Array, k: object = False) -> Array: |
| 228 | + return x + 1 if k else x - 1 |
| 229 | + |
| 230 | + # Basic recognition of static_argnames |
| 231 | + xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1])) |
| 232 | + xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1])) |
| 233 | + xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3])) |
| 234 | + xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3])) |
| 235 | + |
| 236 | + # static argument is not an ArrayLike |
| 237 | + xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3])) |
| 238 | + |
| 239 | + # static argument is not hashable, but serializable |
| 240 | + xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3])) |
| 241 | + |
| 242 | + def test_wrapper(self, jnp: ModuleType): |
| 243 | + @jax_autojit |
| 244 | + def f(w: Wrapper[Array]) -> Wrapper[Array]: |
| 245 | + return Wrapper(w.x + 1) |
| 246 | + |
| 247 | + inp = Wrapper(jnp.asarray([1, 2])) |
| 248 | + out = f(inp).x |
| 249 | + xp_assert_equal(out, jnp.asarray([2, 3])) |
| 250 | + |
| 251 | + def test_static_hashable(self, jnp: ModuleType): |
| 252 | + """Static argument/return value is hashable, but not serializable""" |
| 253 | + |
| 254 | + class C: |
| 255 | + def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride] |
| 256 | + raise Exception() |
| 257 | + |
| 258 | + @jax_autojit |
| 259 | + def f(x: object) -> object: |
| 260 | + return x |
| 261 | + |
| 262 | + inp = C() |
| 263 | + out = f(inp) |
| 264 | + assert out is inp |
| 265 | + |
| 266 | + # Serializable opaque input contains non-serializable object plus array |
| 267 | + inp = Wrapper((C(), jnp.asarray([1, 2]))) |
| 268 | + out = f(inp) |
| 269 | + assert isinstance(out, Wrapper) |
| 270 | + assert out.x[0] is inp.x[0] |
| 271 | + assert out.x[1] is not inp.x[1] |
| 272 | + xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType] |
| 273 | + |
| 274 | + def test_arraylikes_are_static(self): |
| 275 | + pytest.importorskip("jax") |
| 276 | + |
| 277 | + @jax_autojit |
| 278 | + def f(x: list[int]) -> list[int]: |
| 279 | + assert isinstance(x, list) |
| 280 | + assert x == [1, 2] |
| 281 | + return [3, 4] |
| 282 | + |
| 283 | + out = f([1, 2]) |
| 284 | + assert isinstance(out, list) |
| 285 | + assert out == [3, 4] |
| 286 | + |
| 287 | + def test_repeated_objects(self, jnp: ModuleType): |
| 288 | + @jax_autojit |
| 289 | + def f(x: Array, y: Array) -> tuple[Array, Array]: |
| 290 | + z = x + y |
| 291 | + return z, z |
| 292 | + |
| 293 | + inp = jnp.asarray([1, 2]) |
| 294 | + out = f(inp, inp) |
| 295 | + assert out[0] is out[1] |
| 296 | + xp_assert_equal(out[0], jnp.asarray([2, 4])) |
0 commit comments