
Commit 68ab6e9

Add Grain iterator checkpoint/resume and fix num_batches
Enable deterministic mid-epoch resume for Grain datasets by saving and restoring the DatasetIterator state through BackupAndRestore. Also fix num_batches to return the actual count for finite MapDatasets so progress bars work correctly.
1 parent 18a18d3 commit 68ab6e9
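
In user terms, this commit makes mid-epoch fault tolerance work end to end from `fit()`. A minimal sketch of the workflow it enables (the toy data, backup path, and `save_freq` value are illustrative; assumes a Grain version whose dataset iterators implement `get_state()` / `set_state()`):

    import numpy as np
    import grain
    import keras

    # Toy data; any finite grain.MapDataset behaves the same way.
    x = np.random.random((256, 4)).astype("float32")
    y = np.random.random((256, 1)).astype("float32")
    dataset = grain.MapDataset.source(list(zip(x, y))).batch(32)

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    # save_freq=4 checkpoints every 4 batches, i.e. mid-epoch.
    backup = keras.callbacks.BackupAndRestore(
        backup_dir="/tmp/train_backup",  # illustrative path
        save_freq=4,
    )

    # If this run is killed and restarted, BackupAndRestore now restores
    # the Grain iterator position along with the weights and epoch, so
    # training resumes from the exact batch where it stopped.
    model.fit(dataset, epochs=5, callbacks=[backup])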

8 files changed: +248 −25 lines changed

keras/src/backend/jax/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -418,6 +418,10 @@ def fit(
         self._symbolic_build(iterator=epoch_iterator)
         epoch_iterator.reset()

+        # Expose the iterator so callbacks (e.g. BackupAndRestore) can
+        # save / restore data-pipeline state for fault tolerance.
+        self._epoch_iterator = epoch_iterator
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -541,6 +545,7 @@ def fit(
         # are done.
         if getattr(self, "_eval_epoch_iterator", None) is not None:
             del self._eval_epoch_iterator
+        self._epoch_iterator = None
         if training_finished:
             callbacks.on_train_end(logs=training_logs)
         self._jax_state = None

keras/src/backend/tensorflow/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -372,6 +372,10 @@ def fit(
         self._maybe_symbolic_build(iterator=epoch_iterator)
         epoch_iterator.reset()

+        # Expose the iterator so callbacks (e.g. BackupAndRestore) can
+        # save / restore data-pipeline state for fault tolerance.
+        self._epoch_iterator = epoch_iterator
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -449,6 +453,7 @@ def fit(
         # If _eval_epoch_iterator exists, delete it after all epochs are done.
         if getattr(self, "_eval_epoch_iterator", None) is not None:
             del self._eval_epoch_iterator
+        self._epoch_iterator = None
         callbacks.on_train_end(logs=training_logs)
         return self.history

keras/src/backend/torch/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -236,6 +236,10 @@ def fit(
         self._symbolic_build(iterator=epoch_iterator)
         epoch_iterator.reset()

+        # Expose the iterator so callbacks (e.g. BackupAndRestore) can
+        # save / restore data-pipeline state for fault tolerance.
+        self._epoch_iterator = epoch_iterator
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -324,6 +328,7 @@ def fit(
         # If _eval_epoch_iterator exists, delete it after all epochs are done.
         if getattr(self, "_eval_epoch_iterator", None) is not None:
             del self._eval_epoch_iterator
+        self._epoch_iterator = None
         callbacks.on_train_end(logs=training_logs)
         return self.history
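
The JAX, TensorFlow, and Torch trainers now all publish the live epoch iterator on the model for the duration of `fit()`, so any callback, not just BackupAndRestore, can snapshot the pipeline. A hedged sketch (`PipelineProbe` is hypothetical; `get_iterator_state()` returns `None` for adapters without checkpoint support):

    import keras

    class PipelineProbe(keras.callbacks.Callback):
        """Hypothetical callback that reads the data-pipeline state."""

        def on_train_batch_end(self, batch, logs=None):
            epoch_iterator = getattr(self.model, "_epoch_iterator", None)
            if epoch_iterator is None:
                return  # outside fit(), or an older Keras version
            state = epoch_iterator.get_iterator_state()
            if state is not None:
                # Persist `state` alongside your own checkpoints; feeding
                # it back via set_iterator_state() resumes the pipeline.
                pass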

keras/src/callbacks/backup_and_restore.py

Lines changed: 15 additions & 0 deletions
@@ -151,6 +151,14 @@ def _load_model(self):
             epoch = training_metadata["epoch"]
             self.model._initial_epoch = epoch

+            # Restore data-pipeline iterator state when available (e.g.
+            # Grain datasets support deterministic mid-epoch resume).
+            iterator_state = training_metadata.get("iterator_state")
+            if iterator_state is not None:
+                epoch_iterator = getattr(self.model, "_epoch_iterator", None)
+                if epoch_iterator is not None:
+                    epoch_iterator.set_iterator_state(iterator_state)
+
     def on_epoch_end(self, epoch, logs=None):
         self._current_epoch = epoch + 1
         self._last_batch_seen = 0
@@ -187,6 +195,13 @@ def _save_model(self):
                 "epoch": self._current_epoch,
                 "batch": self._last_batch_seen,
             }
+            # Persist data-pipeline iterator state when the adapter
+            # supports it (e.g. Grain).
+            epoch_iterator = getattr(self.model, "_epoch_iterator", None)
+            if epoch_iterator is not None:
+                iterator_state = epoch_iterator.get_iterator_state()
+                if iterator_state is not None:
+                    training_metadata["iterator_state"] = iterator_state
             f.write(json.dumps(training_metadata))

     def _should_save_on_batch(self, batch):
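
The metadata file BackupAndRestore writes remains plain JSON; the new key simply rides along. The shape below is illustrative only — the contents of `iterator_state` are whatever the Grain iterator's `get_state()` returns, and vary by Grain version:

    import json

    training_metadata = {
        "epoch": 3,
        "batch": 120,
        # Hypothetical Grain state dict; the real keys are version-dependent.
        "iterator_state": {"next_index": 3840},
    }
    # The commit serializes this with json.dumps, so the state must
    # round-trip through JSON.
    assert json.loads(json.dumps(training_metadata)) == training_metadata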

keras/src/trainers/data_adapters/data_adapter.py

Lines changed: 18 additions & 0 deletions
@@ -88,6 +88,24 @@ def partial_batch_size(self):
         """
         raise NotImplementedError

+    def get_iterator_state(self):
+        """Return serializable state for the current data iterator.
+
+        Adapters that support deterministic checkpoint/resume (e.g. Grain)
+        override this to return a small dict that can reconstruct the
+        iterator position. The default returns ``None`` (not supported).
+        """
+        return None
+
+    def set_iterator_state(self, state):
+        """Restore the data iterator to a previously saved state.
+
+        Called before the next ``iter()`` call so the iterator resumes
+        from the saved position. Adapters that do not support
+        checkpointing ignore this (the default is a no-op).
+        """
+        pass
+
     def on_epoch_begin(self):
         """A hook called before each epoch."""
         pass
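
Because the defaults are a no-op, third-party adapters opt in simply by overriding the pair. A sketch under stated assumptions (`MyListAdapter` and its integer cursor are hypothetical, and the other required `DataAdapter` properties are elided for brevity):

    from keras.src.trainers.data_adapters.data_adapter import DataAdapter

    class MyListAdapter(DataAdapter):
        """Hypothetical adapter over an in-memory list of batches."""

        def __init__(self, batches):
            self._batches = batches
            self._cursor = 0
            self._pending_cursor = None

        def get_numpy_iterator(self):
            # Honor a restored position, if one was set before iter().
            start = self._pending_cursor or 0
            self._pending_cursor = None
            for i in range(start, len(self._batches)):
                self._cursor = i + 1  # next position to resume from
                yield self._batches[i]

        def get_iterator_state(self):
            return {"cursor": self._cursor}

        def set_iterator_state(self, state):
            self._pending_cursor = state["cursor"]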

keras/src/trainers/data_adapters/grain_dataset_adapter.py

Lines changed: 84 additions & 20 deletions
@@ -1,4 +1,5 @@
 import itertools
+import sys

 import numpy as np

@@ -9,17 +10,43 @@
 from keras.src.utils.module_utils import tensorflow as tf


+class _TrackableIterable:
+    """Wrapper that captures the live ``DatasetIterator`` on ``iter()``.
+
+    When the ``EpochIterator`` calls ``iter()`` on the object returned by
+    ``get_numpy_iterator()`` / ``get_jax_iterator()``, this wrapper
+    stores the resulting iterator on the adapter so that
+    ``get_iterator_state()`` can reach it. If a pending state was
+    previously set via ``set_iterator_state()``, it is applied to the
+    fresh iterator immediately.
+    """
+
+    def __init__(self, dataset, adapter):
+        self._dataset = dataset
+        self._adapter = adapter
+
+    def __iter__(self):
+        it = iter(self._dataset)
+        self._adapter._live_iterator = it
+        if self._adapter._pending_iterator_state is not None:
+            if hasattr(it, "set_state"):
+                it.set_state(self._adapter._pending_iterator_state)
+            self._adapter._pending_iterator_state = None
+        return it
+
+
 class GrainDatasetAdapter(DataAdapter):
-    """Adapter that handles `grain.DataLoader`, `grain.MapDataset` and
-    `grain.IterDataset`.
+    """Adapter that handles ``grain.DataLoader``, ``grain.MapDataset`` and
+    ``grain.IterDataset``.
     """

     def __init__(self, dataset):
         """Initialize the GrainDatasetAdapter.

         Args:
             dataset: A Grain dataset instance. Must be one of
-                `grain.DataLoader`, `grain.MapDataset`, or `grain.IterDataset`.
+                ``grain.DataLoader``, ``grain.MapDataset``, or
+                ``grain.IterDataset``.
         """

         if not isinstance(
@@ -32,17 +59,19 @@ def __init__(self, dataset):
             )

         self._dataset = dataset
+        self._live_iterator = None
+        self._pending_iterator_state = None

         batch_size, output_signature = self._get_dataset_info(dataset)
         self._batch_size = batch_size
         self._output_signature = output_signature
         self._output_tf_signature = None

     def _get_dataset_info(self, dataset):
-        """Get the `batch_size` and `output_signature` from the dataset.
+        """Get the ``batch_size`` and ``output_signature`` from the dataset.

-        We use a small list of batches to infer the `batch_size` and
-        `output_signature`.
+        We use a small list of batches to infer the ``batch_size`` and
+        ``output_signature``.
         """
         batches = list(
             itertools.islice(
@@ -73,9 +102,9 @@ def convert_to_numpy(x):
             if isinstance(x, (np.ndarray, SharedMemoryArrayMetadata)):
                 return x
             else:
-                # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
-                # `torch.Tensor`, as well as any other tensor-like object that
-                # has added numpy support.
+                # Using ``__array__`` should handle ``tf.Tensor``,
+                # ``jax.np.ndarray``, ``torch.Tensor``, as well as any
+                # other tensor-like object that has added numpy support.
                 if hasattr(x, "__array__"):
                     if data_adapter_utils.is_torch_tensor(x):
                         x = x.cpu()
@@ -90,20 +119,21 @@ def map(self, x):

         if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
             dataset = self._dataset.map(ConvertToNumpy())
+            return _TrackableIterable(dataset, self)
         else:
-            # Instantiate a new `DataLoader`.
+            # Instantiate a new ``DataLoader``.
             dataset = grain.DataLoader(
                 data_source=self._dataset._data_source,
                 sampler=self._dataset._sampler,
-                # Append `ConvertToNumpy`.
+                # Append ``ConvertToNumpy``.
                 operations=list(self._dataset._operations) + [ConvertToNumpy()],
                 worker_count=self._dataset._multiprocessing_options.num_workers,
                 worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size,
                 shard_options=self._dataset._shard_options,
                 read_options=self._dataset._read_options,
                 enable_profiling=self._dataset._multiprocessing_options.enable_profiling,
             )
-        return dataset
+            return dataset

     def get_jax_iterator(self):
         def convert_to_jax_compatible(x):
@@ -121,12 +151,13 @@ def map(self, x):

         if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
             dataset = self._dataset.map(ConvertToJaxCompatible())
+            return _TrackableIterable(dataset, self)
         else:
-            # Instantiate a new `DataLoader`.
+            # Instantiate a new ``DataLoader``.
             dataset = grain.DataLoader(
                 data_source=self._dataset._data_source,
                 sampler=self._dataset._sampler,
-                # Append `ConvertToJaxCompatible`.
+                # Append ``ConvertToJaxCompatible``.
                 operations=list(self._dataset._operations)
                 + [ConvertToJaxCompatible()],
                 worker_count=self._dataset._multiprocessing_options.num_workers,
@@ -135,7 +166,7 @@ def map(self, x):
                 read_options=self._dataset._read_options,
                 enable_profiling=self._dataset._multiprocessing_options.enable_profiling,
             )
-        return dataset
+            return dataset

     def get_tf_dataset(self):
         def convert_to_tf(x):
@@ -151,7 +182,7 @@ class ConvertToTF(grain.transforms.Map):
             def map(self, x):
                 return tree.map_structure(convert_to_tf, x)

-        # `tf.data.Dataset.from_generator` does not support lists as output.
+        # ``tf.data.Dataset.from_generator`` does not support lists as output.
         # We convert lists to tuples.
         class ListToTuple(grain.transforms.Map):
             def map(self, x):
@@ -161,11 +192,11 @@ def map(self, x):
             dataset = self._dataset.map(ConvertToTF())
             dataset = dataset.map(ListToTuple())
         else:
-            # Instantiate a new `DataLoader`.
+            # Instantiate a new ``DataLoader``.
             dataset = grain.DataLoader(
                 data_source=self._dataset._data_source,
                 sampler=self._dataset._sampler,
-                # Append `ConvertToTF` and `ListToTuple`.
+                # Append ``ConvertToTF`` and ``ListToTuple``.
                 operations=list(self._dataset._operations)
                 + [ConvertToTF(), ListToTuple()],
                 worker_count=self._dataset._multiprocessing_options.num_workers,
@@ -196,13 +227,46 @@ def __init__(self, iterable):
             def __iter__(self):
                 return iter(self.iterable)

-        # `batch_size=None` indicates that we should not re-batch
+        if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
+            iterable = _TrackableIterable(self._dataset, self)
+        else:
+            iterable = self._dataset
+
+        # ``batch_size=None`` indicates that we should not re-batch
         return torch_data.DataLoader(
-            ConverterIterableDataset(self._dataset), batch_size=None
+            ConverterIterableDataset(iterable), batch_size=None
         )

+    # ------------------------------------------------------------------
+    # Iterator checkpoint / resume
+    # ------------------------------------------------------------------
+
+    def get_iterator_state(self):
+        if self._live_iterator is not None and hasattr(
+            self._live_iterator, "get_state"
+        ):
+            return self._live_iterator.get_state()
+        return None
+
+    def set_iterator_state(self, state):
+        if state is not None:
+            self._pending_iterator_state = state
+
+    # ------------------------------------------------------------------
+    # Metadata
+    # ------------------------------------------------------------------
+
     @property
     def num_batches(self):
+        if isinstance(self._dataset, grain.MapDataset):
+            try:
+                length = len(self._dataset)
+            except TypeError:
+                return None
+            # ``repeat(None)`` sets length to ``sys.maxsize``.
+            if length >= sys.maxsize:
+                return None
+            return length
         return None

     @property
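
The adapter leans entirely on Grain's native iterator checkpointing. A minimal standalone sketch of those primitives, and of the num_batches fix (assumes a recent Grain version where MapDataset iterators implement get_state() / set_state()):

    import sys
    import grain

    ds = grain.MapDataset.source(list(range(10))).batch(2)

    it = iter(ds)            # Grain DatasetIterator
    next(it), next(it)       # consume batches [0, 1] and [2, 3]
    state = it.get_state()   # small, JSON-serializable dict

    it2 = iter(ds)
    it2.set_state(state)     # fresh iterator, resumed mid-stream
    assert next(it2).tolist() == [4, 5]

    # num_batches: len() of a finite MapDataset is the batch count, but
    # repeat(None) pins len() at sys.maxsize, which the adapter now maps
    # back to "unknown" (None) so progress bars stay honest.
    assert len(ds) == 5
    assert len(ds.repeat(None)) == sys.maxsize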
