11import itertools
2+ import sys
23
34import numpy as np
45
910from keras .src .utils .module_utils import tensorflow as tf
1011
1112
13+ class _TrackableIterable :
14+ """Wrapper that captures the live ``DatasetIterator`` on ``iter()``.
15+
16+ When the ``EpochIterator`` calls ``iter()`` on the object returned by
17+ ``get_numpy_iterator()`` / ``get_jax_iterator()``, this wrapper
18+ stores the resulting iterator on the adapter so that
19+ ``get_iterator_state()`` can reach it. If a pending state was
20+ previously set via ``set_iterator_state()``, it is applied to the
21+ fresh iterator immediately.
22+ """
23+
24+ def __init__ (self , dataset , adapter ):
25+ self ._dataset = dataset
26+ self ._adapter = adapter
27+
28+ def __iter__ (self ):
29+ it = iter (self ._dataset )
30+ self ._adapter ._live_iterator = it
31+ if self ._adapter ._pending_iterator_state is not None :
32+ if hasattr (it , "set_state" ):
33+ it .set_state (self ._adapter ._pending_iterator_state )
34+ self ._adapter ._pending_iterator_state = None
35+ return it
36+
37+
1238class GrainDatasetAdapter (DataAdapter ):
13- """Adapter that handles `grain.DataLoader`, `grain.MapDataset` and
14- `grain.IterDataset`.
    """Adapter that handles ``grain.DataLoader``, ``grain.MapDataset`` and
    ``grain.IterDataset``.
1541 """
1642
1743 def __init__ (self , dataset ):
1844 """Initialize the GrainDatasetAdapter.
1945
2046 Args:
2147 dataset: A Grain dataset instance. Must be one of
22- `grain.DataLoader`, `grain.MapDataset`, or `grain.IterDataset`.
                ``grain.DataLoader``, ``grain.MapDataset``, or
                ``grain.IterDataset``.
2350 """
2451
2552 if not isinstance (
@@ -32,17 +59,19 @@ def __init__(self, dataset):
3259 )
3360
3461 self ._dataset = dataset
62+ self ._live_iterator = None
63+ self ._pending_iterator_state = None
3564
3665 batch_size , output_signature = self ._get_dataset_info (dataset )
3766 self ._batch_size = batch_size
3867 self ._output_signature = output_signature
3968 self ._output_tf_signature = None
4069
4170 def _get_dataset_info (self , dataset ):
42- """Get the `batch_size` and `output_signature` from the dataset.
        """Get the ``batch_size`` and ``output_signature`` from the dataset.
4372
44- We use a small list of batches to infer the `batch_size` and
45- `output_signature`.
        We use a small list of batches to infer the ``batch_size`` and
        ``output_signature``.
4675 """
4776 batches = list (
4877 itertools .islice (
@@ -73,9 +102,9 @@ def convert_to_numpy(x):
73102 if isinstance (x , (np .ndarray , SharedMemoryArrayMetadata )):
74103 return x
75104 else :
76- # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray `,
77- # `torch.Tensor`, as well as any other tensor-like object that
78- # has added numpy support.
                # Using ``__array__`` should handle ``tf.Tensor``,
                # ``jax.np.ndarray``, ``torch.Tensor``, as well as any
                # other tensor-like object that has added numpy support.
79108 if hasattr (x , "__array__" ):
80109 if data_adapter_utils .is_torch_tensor (x ):
81110 x = x .cpu ()
@@ -90,20 +119,21 @@ def map(self, x):
90119
91120 if isinstance (self ._dataset , (grain .MapDataset , grain .IterDataset )):
92121 dataset = self ._dataset .map (ConvertToNumpy ())
122+ return _TrackableIterable (dataset , self )
93123 else :
94- # Instantiate a new `DataLoader`.
124+ # Instantiate a new `` DataLoader` `.
95125 dataset = grain .DataLoader (
96126 data_source = self ._dataset ._data_source ,
97127 sampler = self ._dataset ._sampler ,
98- # Append `ConvertToNumpy`.
128+ # Append `` ConvertToNumpy` `.
99129 operations = list (self ._dataset ._operations ) + [ConvertToNumpy ()],
100130 worker_count = self ._dataset ._multiprocessing_options .num_workers ,
101131 worker_buffer_size = self ._dataset ._multiprocessing_options .per_worker_buffer_size ,
102132 shard_options = self ._dataset ._shard_options ,
103133 read_options = self ._dataset ._read_options ,
104134 enable_profiling = self ._dataset ._multiprocessing_options .enable_profiling ,
105135 )
106- return dataset
136+ return dataset
107137
108138 def get_jax_iterator (self ):
109139 def convert_to_jax_compatible (x ):
@@ -121,12 +151,13 @@ def map(self, x):
121151
122152 if isinstance (self ._dataset , (grain .MapDataset , grain .IterDataset )):
123153 dataset = self ._dataset .map (ConvertToJaxCompatible ())
154+ return _TrackableIterable (dataset , self )
124155 else :
125- # Instantiate a new `DataLoader`.
156+ # Instantiate a new `` DataLoader` `.
126157 dataset = grain .DataLoader (
127158 data_source = self ._dataset ._data_source ,
128159 sampler = self ._dataset ._sampler ,
129- # Append `ConvertToJaxCompatible`.
160+ # Append `` ConvertToJaxCompatible` `.
130161 operations = list (self ._dataset ._operations )
131162 + [ConvertToJaxCompatible ()],
132163 worker_count = self ._dataset ._multiprocessing_options .num_workers ,
@@ -135,7 +166,7 @@ def map(self, x):
135166 read_options = self ._dataset ._read_options ,
136167 enable_profiling = self ._dataset ._multiprocessing_options .enable_profiling ,
137168 )
138- return dataset
169+ return dataset
139170
140171 def get_tf_dataset (self ):
141172 def convert_to_tf (x ):
@@ -151,7 +182,7 @@ class ConvertToTF(grain.transforms.Map):
151182 def map (self , x ):
152183 return tree .map_structure (convert_to_tf , x )
153184
154- # `tf.data.Dataset.from_generator` does not support lists as output.
        # ``tf.data.Dataset.from_generator`` does not support lists as output.
155186 # We convert lists to tuples.
156187 class ListToTuple (grain .transforms .Map ):
157188 def map (self , x ):
@@ -161,11 +192,11 @@ def map(self, x):
161192 dataset = self ._dataset .map (ConvertToTF ())
162193 dataset = dataset .map (ListToTuple ())
163194 else :
164- # Instantiate a new `DataLoader`.
195+ # Instantiate a new `` DataLoader` `.
165196 dataset = grain .DataLoader (
166197 data_source = self ._dataset ._data_source ,
167198 sampler = self ._dataset ._sampler ,
168- # Append `ConvertToTF` and `ListToTuple`.
                # Append ``ConvertToTF`` and ``ListToTuple``.
169200 operations = list (self ._dataset ._operations )
170201 + [ConvertToTF (), ListToTuple ()],
171202 worker_count = self ._dataset ._multiprocessing_options .num_workers ,
@@ -196,13 +227,46 @@ def __init__(self, iterable):
196227 def __iter__ (self ):
197228 return iter (self .iterable )
198229
199- # `batch_size=None` indicates that we should not re-batch
230+ if isinstance (self ._dataset , (grain .MapDataset , grain .IterDataset )):
231+ iterable = _TrackableIterable (self ._dataset , self )
232+ else :
233+ iterable = self ._dataset
234+
        # ``batch_size=None`` indicates that we should not re-batch
200236 return torch_data .DataLoader (
201- ConverterIterableDataset (self . _dataset ), batch_size = None
237+ ConverterIterableDataset (iterable ), batch_size = None
202238 )
203239
240+ # ------------------------------------------------------------------
241+ # Iterator checkpoint / resume
242+ # ------------------------------------------------------------------
243+
244+ def get_iterator_state (self ):
245+ if self ._live_iterator is not None and hasattr (
246+ self ._live_iterator , "get_state"
247+ ):
248+ return self ._live_iterator .get_state ()
249+ return None
250+
251+ def set_iterator_state (self , state ):
252+ if state is not None :
253+ self ._pending_iterator_state = state
254+
255+ # ------------------------------------------------------------------
256+ # Metadata
257+ # ------------------------------------------------------------------
258+
204259 @property
205260 def num_batches (self ):
261+ if isinstance (self ._dataset , grain .MapDataset ):
262+ try :
263+ length = len (self ._dataset )
264+ except TypeError :
265+ return None
266+ # ``repeat(None)`` sets length to ``sys.maxsize``.
267+ if length >= sys .maxsize :
268+ return None
269+ return length
206270 return None
207271
208272 @property
0 commit comments