6
6
import math
7
7
import pickle
8
8
import types
9
- from collections .abc import Callable , Generator , Iterable
9
+ from collections .abc import Callable , Generator , Iterable , Iterator
10
10
from functools import wraps
11
11
from types import ModuleType
12
12
from typing import (
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
512
512
convert them to/from PyTrees.
513
513
"""
514
514
515
- obj : T
515
+ _obj : Any
516
+ _is_iter : bool
516
517
_registered : ClassVar [bool ] = False
517
- __slots__ : tuple [str , ...] = ("obj" , )
518
+ __slots__ : tuple [str , ...] = ("_is_iter" , "_obj" )
518
519
519
520
def __init__ (self , obj : T ) -> None : # numpydoc ignore=GL08
520
521
self ._register ()
521
- self .obj = obj
522
+ if isinstance (obj , Iterator ):
523
+ self ._obj = list (obj )
524
+ self ._is_iter = True
525
+ else :
526
+ self ._obj = obj
527
+ self ._is_iter = False
528
+
529
+ @property
530
+ def obj (self ) -> T : # numpydoc ignore=RT01
531
+ """Return wrapped object."""
532
+ return iter (self ._obj ) if self ._is_iter else self ._obj
522
533
523
534
@classmethod
524
535
def _register (cls ) -> None : # numpydoc ignore=SS06
@@ -531,7 +542,7 @@ def _register(cls) -> None: # numpydoc ignore=SS06
531
542
532
543
jax .tree_util .register_pytree_node (
533
544
cls ,
534
- lambda obj : pickle_flatten (obj , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
545
+ lambda instance : pickle_flatten (instance , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
535
546
lambda aux_data , children : pickle_unflatten (children , aux_data ), # pyright: ignore[reportUnknownArgumentType]
536
547
)
537
548
cls ._registered = True
@@ -556,6 +567,7 @@ def jax_autojit(
556
567
- Automatically descend into non-array return values and find ``jax.Array`` objects
557
568
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
558
569
tracer objects with concrete arrays.
570
+ - Returned iterators are immediately completely consumed.
559
571
560
572
See Also
561
573
--------
0 commit comments