Skip to content

Commit 87c86ac

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 836794431
1 parent 8001385 commit 87c86ac

File tree

3 files changed

+138
-3
lines changed

3 files changed

+138
-3
lines changed

grain/_src/python/dataset/transformations/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ py_test(
142142
"//grain/_src/python/dataset",
143143
"//grain/_src/python/dataset:base",
144144
"@abseil-py//absl/testing:absltest",
145+
"@abseil-py//absl/testing:parameterized",
145146
"@pypi//numpy:pkg",
146147
],
147148
)

grain/_src/python/dataset/transformations/mix.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import bisect
19+
import collections
1920
from collections.abc import Sequence
2021
import sys
2122
from typing import Any, Mapping, TypeVar
@@ -132,10 +133,38 @@ def __getitem__(self, index):
132133
if sys.version_info >= (3, 11):
133134
e.add_note(
134135
f"Exception caught while processing dataset @ {dataset_index=},"
135-
f" {index_in_dataset=}"
136136
)
137137
raise e
138138

139+
def _getitems(self, indices: Sequence[int]) -> Sequence[T | None]:
140+
"""Returns a sequence of elements corresponding to the given indices."""
141+
142+
with self._stats.record_self_time(num_elements=len(indices)):
143+
parent_and_key_pairs = (self._selection_map[i] for i in indices)
144+
145+
# Group indices by parent.
146+
# keys_by_parent will store {parent_idx: [key1, key2, ...]}
147+
# where key is the index in the parent dataset.
148+
keys_by_parent = collections.defaultdict(list)
149+
# original_indices_by_parent will store {parent_idx: [original_idx1, ...]}
150+
# where original_idx is the position in the input `indices` list.
151+
original_indices_by_parent = collections.defaultdict(list)
152+
153+
for i, (parent_idx, key) in enumerate(parent_and_key_pairs):
154+
keys_by_parent[parent_idx].append(key)
155+
original_indices_by_parent[parent_idx].append(i)
156+
157+
mixed_elements = [None] * len(indices)
158+
for parent_idx, keys in keys_by_parent.items():
159+
# Call _getitems for each parent.
160+
parent_elements = self.parents[parent_idx]._getitems(keys) # pylint: disable=protected-access
161+
original_indices = original_indices_by_parent[parent_idx]
162+
# Assign elements to their proper index in the final mixed dataset.
163+
for i, element in enumerate(parent_elements):
164+
mixed_elements[original_indices[i]] = element
165+
166+
return self._stats.record_output_spec_for_batch(mixed_elements)
167+
139168

140169
class _MixedDatasetIterator(dataset.DatasetIterator[T]):
141170
"""Iterator that mixes elements from iterators based on given proportions.

grain/_src/python/dataset/transformations/mix_test.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Callable, Tuple
1818

1919
from absl.testing import absltest
20+
from absl.testing import parameterized
2021
from grain._src.python.dataset import base
2122
from grain._src.python.dataset import dataset
2223
from grain._src.python.dataset.transformations import mix
@@ -154,7 +155,7 @@ def _subset_and_shuffle_dataset(index):
154155
self.assertEqual(expected_dataset, unrolled_dataset)
155156

156157

157-
class MixedMapDatasetTest(absltest.TestCase):
158+
class MixedMapDatasetTest(parameterized.TestCase):
158159

159160
def setUp(self):
160161
super().setUp()
@@ -169,7 +170,7 @@ def test_len(self):
169170
# Equal proportions.
170171
ds = mix.MixedMapDataset([ds1, ds2, ds3])
171172
self.assertLen(ds, 15)
172-
# Heigher weight for second dataset.
173+
# Higher weight for second dataset.
173174
ds = mix.MixedMapDataset([ds1, ds2, ds3], proportions=[1, 2, 1])
174175
self.assertLen(ds, 5 + 10 + 5)
175176

@@ -333,6 +334,99 @@ def _subset_and_shuffle_dataset(index):
333334

334335
self.assertEqual(list(ds), expected_dataset)
335336

337+
@parameterized.named_parameters(
338+
dict(
339+
testcase_name="equal_proportions",
340+
proportions=[1, 1],
341+
expected_values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
342+
),
343+
dict(
344+
testcase_name="float_proportions",
345+
proportions=[0.75, 0.25],
346+
expected_values=[0, 2, 4, 1, 6, 8, 0, 3, 2, 4],
347+
),
348+
dict(
349+
testcase_name="integer_proportions",
350+
proportions=[1, 2],
351+
expected_values=[0, 1, 3, 2, 5, 7, 4, 9, 1, 6],
352+
),
353+
)
354+
def test_getitems_mixing(self, proportions, expected_values):
355+
mixed_ds = mix.MixedMapDataset(
356+
parents=[self.even_ds, self.odd_ds],
357+
proportions=proportions,
358+
)
359+
# Request a batch of indices that results in interleaved elements
360+
indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
361+
actual_values = mixed_ds._getitems(indices)
362+
self.assertEqual(expected_values, actual_values)
363+
364+
def test_getitems_interleaved_map(self):
365+
def _inteleaved_dataset(index):
366+
if index > 9:
367+
raise IndexError("index our of range")
368+
ds = index % 2
369+
ds_index = index // 2
370+
return (ds, ds_index)
371+
372+
interleaved_map = ExplicitSelectionMap(10, _inteleaved_dataset)
373+
374+
ds = mix.MixedMapDataset(
375+
parents=[self.even_ds, self.odd_ds], selection_map=interleaved_map
376+
)
377+
378+
expected_dataset = list(range(10))
379+
self.assertEqual(ds._getitems(list(range(10))), expected_dataset)
380+
381+
def test_getitems_sequential_map(self):
382+
def _sequential_dataset(index):
383+
if index > 9:
384+
raise IndexError("index our of range")
385+
if index < 5:
386+
ds = 0
387+
else:
388+
ds = 1
389+
ds_index = index % 5
390+
return (ds, ds_index)
391+
392+
sequential_map = ExplicitSelectionMap(10, _sequential_dataset)
393+
394+
ds = mix.MixedMapDataset(
395+
parents=[self.even_ds, self.odd_ds], selection_map=sequential_map
396+
)
397+
398+
expected_dataset = list(range(0, 10, 2)) + list(range(1, 10, 2))
399+
self.assertEqual(ds._getitems(list(range(10))), expected_dataset)
400+
401+
def test_getitems_subset_and_shuffle_map(self):
402+
first_epoch = [0, 1, 2, 3, 4]
403+
second_epoch = [1, 0, 3, 2, 4]
404+
405+
expected_dataset = first_epoch + second_epoch
406+
407+
def _subset_and_shuffle_dataset(index):
408+
if index > 9:
409+
raise IndexError("index our of range")
410+
if index < 5:
411+
ds = index % 2
412+
ds_index = index // 2
413+
else:
414+
mapped_index = second_epoch[index - 5]
415+
ds = mapped_index % 2
416+
ds_index = mapped_index // 2
417+
return (ds, ds_index)
418+
419+
subset_and_shuffle_map = ExplicitSelectionMap(
420+
10, _subset_and_shuffle_dataset
421+
)
422+
423+
ds = mix.MixedMapDataset(
424+
parents=[self.even_ds, self.odd_ds],
425+
selection_map=subset_and_shuffle_map,
426+
)
427+
428+
self.assertEqual(ds._getitems(list(range(10))), expected_dataset)
429+
336430

337431
class MixedIterDatasetTest(absltest.TestCase):
338432

@@ -627,6 +721,17 @@ def test_cannot_concatenate_infinite_datasets(self):
627721
):
628722
_ = mix._ConcatSelectionMap([zeros, ones])
629723

724+
def test_getitems_concatenate_finite_datasets(self):
725+
evens = dataset.MapDataset.range(0, 10, 2)
726+
odds = dataset.MapDataset.range(1, 10, 2)
727+
ds = mix.ConcatenateMapDataset([evens, odds])
728+
self.assertLen(evens, 5)
729+
self.assertLen(odds, 5)
730+
self.assertLen(ds, 10)
731+
732+
expected_values = [0, 2, 4, 6, 8, 1, 3, 5, 7, 9]
733+
self.assertEqual(ds._getitems(list(range(10))), expected_values)
734+
630735

631736
if __name__ == "__main__":
632737
absltest.main()

0 commit comments

Comments
 (0)