| 
1 | 1 | from types import ModuleType  | 
2 |  | -from typing import cast  | 
 | 2 | +from typing import Generic, TypeVar, cast  | 
3 | 3 | 
 
  | 
4 | 4 | import numpy as np  | 
5 | 5 | import pytest  | 
 | 
13 | 13 |     capabilities,  | 
14 | 14 |     eager_shape,  | 
15 | 15 |     in1d,  | 
 | 16 | +    jax_autojit,  | 
16 | 17 |     meta_namespace,  | 
17 | 18 |     ndindex,  | 
18 | 19 | )  | 
 | 
23 | 24 | 
 
  | 
24 | 25 | # mypy: disable-error-code=no-untyped-usage  | 
25 | 26 | 
 
  | 
 | 27 | +T = TypeVar("T")  | 
 | 28 | + | 
26 | 29 | # FIXME calls xp.unique_values without size  | 
27 | 30 | lazy_xp_function(in1d, jax_jit=False)  | 
28 | 31 | 
 
  | 
@@ -204,3 +207,90 @@ def test_capabilities(xp: ModuleType):  | 
204 | 207 |     if xp.__array_api_version__ >= "2024.12":  | 
205 | 208 |         expect.add("max dimensions")  | 
206 | 209 |     assert capabilities(xp).keys() == expect  | 
 | 210 | + | 
 | 211 | + | 
 | 212 | +class Wrapper(Generic[T]):  | 
 | 213 | +    """Trivial opaque wrapper. Must be pickleable."""  | 
 | 214 | + | 
 | 215 | +    x: T  | 
 | 216 | +    # __slots__ make this object serializable with __reduce_ex__(5),  | 
 | 217 | +    # but not with __reduce__  | 
 | 218 | +    __slots__: tuple[str, ...] = ("x",)  | 
 | 219 | + | 
 | 220 | +    def __init__(self, x: T):  | 
 | 221 | +        self.x = x  | 
 | 222 | + | 
 | 223 | + | 
 | 224 | +class TestJAXAutoJIT:  | 
 | 225 | +    def test_basic(self, jnp: ModuleType):  | 
 | 226 | +        @jax_autojit  | 
 | 227 | +        def f(x: Array, k: object = False) -> Array:  | 
 | 228 | +            return x + 1 if k else x - 1  | 
 | 229 | + | 
 | 230 | +        # Basic recognition of static_argnames  | 
 | 231 | +        xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1]))  | 
 | 232 | +        xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1]))  | 
 | 233 | +        xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3]))  | 
 | 234 | +        xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3]))  | 
 | 235 | + | 
 | 236 | +        # static argument is not an ArrayLike  | 
 | 237 | +        xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3]))  | 
 | 238 | + | 
 | 239 | +        # static argument is not hashable, but serializable  | 
 | 240 | +        xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3]))  | 
 | 241 | + | 
 | 242 | +    def test_wrapper(self, jnp: ModuleType):  | 
 | 243 | +        @jax_autojit  | 
 | 244 | +        def f(w: Wrapper[Array]) -> Wrapper[Array]:  | 
 | 245 | +            return Wrapper(w.x + 1)  | 
 | 246 | + | 
 | 247 | +        inp = Wrapper(jnp.asarray([1, 2]))  | 
 | 248 | +        out = f(inp).x  | 
 | 249 | +        xp_assert_equal(out, jnp.asarray([2, 3]))  | 
 | 250 | + | 
 | 251 | +    def test_static_hashable(self, jnp: ModuleType):  | 
 | 252 | +        """Static argument/return value is hashable, but not serializable"""  | 
 | 253 | + | 
 | 254 | +        class C:  | 
 | 255 | +            def __reduce__(self) -> object:  # type: ignore[explicit-override,override]  # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]  | 
 | 256 | +                raise Exception()  | 
 | 257 | + | 
 | 258 | +        @jax_autojit  | 
 | 259 | +        def f(x: object) -> object:  | 
 | 260 | +            return x  | 
 | 261 | + | 
 | 262 | +        inp = C()  | 
 | 263 | +        out = f(inp)  | 
 | 264 | +        assert out is inp  | 
 | 265 | + | 
 | 266 | +        # Serializable opaque input contains non-serializable object plus array  | 
 | 267 | +        inp = Wrapper((C(), jnp.asarray([1, 2])))  | 
 | 268 | +        out = f(inp)  | 
 | 269 | +        assert isinstance(out, Wrapper)  | 
 | 270 | +        assert out.x[0] is inp.x[0]  | 
 | 271 | +        assert out.x[1] is not inp.x[1]  | 
 | 272 | +        xp_assert_equal(out.x[1], inp.x[1])  # pyright: ignore[reportUnknownArgumentType]  | 
 | 273 | + | 
 | 274 | +    def test_arraylikes_are_static(self):  | 
 | 275 | +        pytest.importorskip("jax")  | 
 | 276 | + | 
 | 277 | +        @jax_autojit  | 
 | 278 | +        def f(x: list[int]) -> list[int]:  | 
 | 279 | +            assert isinstance(x, list)  | 
 | 280 | +            assert x == [1, 2]  | 
 | 281 | +            return [3, 4]  | 
 | 282 | + | 
 | 283 | +        out = f([1, 2])  | 
 | 284 | +        assert isinstance(out, list)  | 
 | 285 | +        assert out == [3, 4]  | 
 | 286 | + | 
 | 287 | +    def test_repeated_objects(self, jnp: ModuleType):  | 
 | 288 | +        @jax_autojit  | 
 | 289 | +        def f(x: Array, y: Array) -> tuple[Array, Array]:  | 
 | 290 | +            z = x + y  | 
 | 291 | +            return z, z  | 
 | 292 | + | 
 | 293 | +        inp = jnp.asarray([1, 2])  | 
 | 294 | +        out = f(inp, inp)  | 
 | 295 | +        assert out[0] is out[1]  | 
 | 296 | +        xp_assert_equal(out[0], jnp.asarray([2, 4]))  | 
0 commit comments