Skip to content

Commit 9d03a99

Browse files
authored
ENH: lazy_xp_function support for iterators (#418)
1 parent 9e11827 commit 9d03a99

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import math
77
import pickle
88
import types
9-
from collections.abc import Callable, Generator, Iterable
9+
from collections.abc import Callable, Generator, Iterable, Iterator
1010
from functools import wraps
1111
from types import ModuleType
1212
from typing import (
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
512512
convert them to/from PyTrees.
513513
"""
514514

515-
obj: T
515+
_obj: Any
516+
_is_iter: bool
516517
_registered: ClassVar[bool] = False
517-
__slots__: tuple[str, ...] = ("obj",)
518+
__slots__: tuple[str, ...] = ("_is_iter", "_obj")
518519

519520
def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
520521
self._register()
521-
self.obj = obj
522+
if isinstance(obj, Iterator):
523+
self._obj = list(obj)
524+
self._is_iter = True
525+
else:
526+
self._obj = obj
527+
self._is_iter = False
528+
529+
@property
530+
def obj(self) -> T: # numpydoc ignore=RT01
531+
"""Return wrapped object."""
532+
return iter(self._obj) if self._is_iter else self._obj
522533

523534
@classmethod
524535
def _register(cls) -> None: # numpydoc ignore=SS06
@@ -531,7 +542,7 @@ def _register(cls) -> None: # numpydoc ignore=SS06
531542

532543
jax.tree_util.register_pytree_node(
533544
cls,
534-
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
545+
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
535546
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
536547
)
537548
cls._registered = True
@@ -556,6 +567,7 @@ def jax_autojit(
556567
- Automatically descend into non-array return values and find ``jax.Array`` objects
557568
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
558569
tracer objects with concrete arrays.
570+
- Returned iterators are immediately completely consumed.
559571
560572
See Also
561573
--------

tests/test_helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator
12
from types import ModuleType
23
from typing import TYPE_CHECKING, Generic, TypeVar, cast
34

@@ -417,3 +418,16 @@ def f(x: list[int]) -> list[int]:
417418
out = f([1, 2])
418419
assert isinstance(out, list)
419420
assert out == [3, 4]
421+
422+
def test_iterators(self, jnp: ModuleType):
423+
@jax_autojit
424+
def f(x: Array) -> Iterator[Array]:
425+
return (x + i for i in range(2))
426+
427+
inp = jnp.asarray([1, 2])
428+
out = f(inp)
429+
assert isinstance(out, Iterator)
430+
xp_assert_equal(next(out), jnp.asarray([1, 2]))
431+
xp_assert_equal(next(out), jnp.asarray([2, 3]))
432+
with pytest.raises(StopIteration):
433+
_ = next(out)

tests/test_testing.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Iterator
22
from types import ModuleType
33
from typing import cast
44

@@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch(
468468
monkeypatch.undo()
469469
y = non_materializable5(x)
470470
xp_assert_equal(y, x)
471+
472+
473+
def my_iter(x: Array) -> Iterator[Array]:
474+
yield x[0, :]
475+
yield x[1, :]
476+
477+
478+
lazy_xp_function(my_iter)
479+
480+
481+
def test_patch_lazy_xp_functions_iter(xp: ModuleType):
482+
x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
483+
it = my_iter(x)
484+
485+
assert isinstance(it, Iterator)
486+
xp_assert_equal(next(it), x[0, :])
487+
xp_assert_equal(next(it), x[1, :])
488+
with pytest.raises(StopIteration):
489+
_ = next(it)

0 commit comments

Comments
 (0)