Skip to content

Commit 4cb5e38

Browse files
authored
PyDataset now implements __iter__. (#21330)
A `PyDataset` can now be iterated without failure when it has a finite number of batches. Previously, it was iterable only by accident, due to a legacy feature of Python that predates the introduction of `__iter__`: when a class implements `__getitem__` but not `__iter__`, Python iterates over it by passing sequential indices. Under that legacy protocol, however, `__getitem__` is expected to raise an `IndexError` to signal the end of iteration, because Python ignores `__len__`. `PyDataset` now implements `__iter__` and returns an iterator that knows about the length of the dataset. Fixes #21151
1 parent 81821e0 commit 4cb5e38

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

keras/src/trainers/data_adapters/py_dataset_adapter.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,24 @@ def __getitem__(self, index):
152152
Returns:
153153
A batch
154154
"""
155+
del index
155156
raise NotImplementedError
156157

158+
def __iter__(self):
    """Yield batches in index order by calling `self[index]`.

    When `num_batches` is a number, indices `0 .. num_batches - 1` are
    visited and iteration then stops. When `num_batches` is `None`, or
    reading it raises `NotImplementedError`, the length is treated as
    unknown and indices count upward without bound.

    Yields:
        The value of `self[index]` for each successive index.
    """
    # Only the property access can signal "length not implemented".
    try:
        batch_count = self.num_batches
    except NotImplementedError:
        batch_count = None

    if batch_count is None:
        # Unknown length: the consumer decides when to stop.
        indices = itertools.count()
    else:
        indices = range(batch_count)

    for position in indices:
        yield self[position]
172+
157173
@property
158174
def num_batches(self):
159175
"""Number of batches in the PyDataset.

keras/src/trainers/data_adapters/py_dataset_adapter_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,32 @@ def test_exception_reported(
422422
expected_exception_class, "Expected exception"
423423
):
424424
next(it)
425+
426+
def test_iterate_finite(self):
    # Six rows with a batch size of two: plain iteration must stop on its
    # own after exactly three batches.
    features = np.ones((6, 11), dtype="int32")
    labels = np.zeros((6, 11), dtype="int32")
    dataset = ExamplePyDataset(features, labels, batch_size=2)
    self.assertLen(list(dataset), 3)
434+
435+
def test_iterate_infinite_with_none_num_batches(self):
    # With `infinite=True` the dataset is iterated without a known end, so
    # the loop itself stops after pulling eleven batches.
    features = np.ones((6, 11), dtype="int32")
    labels = np.zeros((6, 11), dtype="int32")
    dataset = ExamplePyDataset(
        features, labels, batch_size=2, infinite=True
    )
    batches_seen = 0
    for _ in dataset:
        batches_seen += 1
        if batches_seen > 10:
            break
445+
446+
def test_iterate_infinite_with_no_len(self):
    # A dataset that defines only `__getitem__` (no `num_batches`, no
    # `__len__`) must still be iterable; the loop stops itself after
    # eleven batches.
    class NoLenDataset(py_dataset_adapter.PyDataset):
        def __getitem__(self, idx):
            # `return`, not `yield`: a `yield` here would turn
            # `__getitem__` into a generator function and every "batch"
            # would be a generator object instead of an array.
            return np.ones((2, 11), dtype="int32")

    for index, batch in enumerate(NoLenDataset()):
        # Each item must be the array produced by `__getitem__`.
        self.assertEqual(batch.shape, (2, 11))
        if index >= 10:
            break

0 commit comments

Comments
 (0)