Skip to content

Commit 784dc7e

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 833422113
1 parent 18c543f commit 784dc7e

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

grain/_src/python/dataset/transformations/map.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,34 @@ def _transform_name(self):
161161
def __str__(self) -> str:
162162
return f"RandomMapMapDataset(transform={self._transform_name})"
163163

164+
def _random_map_element(self, element: Any, index: int) -> T:
165+
if element is None:
166+
return None
167+
rng = self._rng_pool.acquire_rng(index)
168+
element = self._map_fn(element, rng)
169+
self._rng_pool.release_rng(rng)
170+
return element
171+
164172
def __getitem__(self, index):
165173
if isinstance(index, slice):
166174
return self.slice(index)
167175
element = self._parent[index]
168176
with self._stats.record_self_time():
169-
if element is None:
170-
return None
171-
rng = self._rng_pool.acquire_rng(index)
172-
element = self._map_fn(element, rng)
173-
self._rng_pool.release_rng(rng)
174-
return self._stats.record_output_spec(element)
177+
mapped_element = self._random_map_element(element, index)
178+
return (
179+
self._stats.record_output_spec(mapped_element)
180+
if mapped_element is not None
181+
else None
182+
)
183+
184+
def _getitems(self, indices: Sequence[int]):
185+
elements = self._parent._getitems(indices) # pylint: disable=protected-access
186+
with self._stats.record_self_time(num_elements=len(indices)):
187+
processed_elements = [
188+
self._random_map_element(element, index)
189+
for element, index in zip(elements, indices)
190+
]
191+
return self._stats.record_output_spec_for_batch(processed_elements)
175192

176193

177194
class MapWithIndexMapDataset(dataset.MapDataset[T]):
@@ -201,14 +218,31 @@ def __len__(self) -> int:
201218
def __str__(self) -> str:
202219
return f"MapWithIndexMapDataset(transform={self._transform_name})"
203220

221+
def _map_with_index_fn(self, index: int, element: Any) -> T:
222+
if element is None:
223+
return None
224+
return self._map_fn(index, element)
225+
204226
def __getitem__(self, index):
205227
if isinstance(index, slice):
206228
return self.slice(index)
207229
element = self._parent[index]
208230
with self._stats.record_self_time():
209-
if element is None:
210-
return None
211-
return self._stats.record_output_spec(self._map_fn(index, element))
231+
mapped_element = self._map_with_index_fn(index, element)
232+
return (
233+
self._stats.record_output_spec(mapped_element)
234+
if mapped_element is not None
235+
else None
236+
)
237+
238+
def _getitems(self, indices: Sequence[int]):
239+
elements = self._parent._getitems(indices) # pylint: disable=protected-access
240+
with self._stats.record_self_time(num_elements=len(indices)):
241+
processed_elements = [
242+
self._map_with_index_fn(index, element)
243+
for index, element in zip(indices, elements)
244+
]
245+
return self._stats.record_output_spec_for_batch(processed_elements)
212246

213247

214248
class _MapDatasetIterator(dataset.DatasetIterator[T]):

grain/_src/python/dataset/transformations/map_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,19 @@ def test_random_map_raises_with_no_seed(self, transform):
204204
with self.assertRaises(ValueError):
205205
map_ds.RandomMapMapDataset(self.range_ds, transform)
206206

207+
@parameterized.parameters(
208+
dict(indices=list(range(10))),
209+
dict(indices=[0, 4, 8]),
210+
dict(indices=[3, 1, 7, 2]),
211+
)
212+
def test_random_map_data_with_get_items(self, indices):
213+
ds = map_ds.RandomMapMapDataset(
214+
self.range_ds, RandomMapWithDeterminismTransform(), seed=42
215+
)
216+
expected_data = [ds[i] for i in indices]
217+
actual_data = ds._getitems(indices)
218+
np.testing.assert_equal(expected_data, actual_data)
219+
207220
@parameterized.parameters(0, 1, 42)
208221
def test_random_map_checkpointing(self, random_map_seed):
209222
ds = map_ds.RandomMapMapDataset(
@@ -309,7 +322,7 @@ def test_random_map_checkpointing(self, random_map_seed):
309322
assert_equal_output_after_checkpoint(ds)
310323

311324

312-
class MapWithIndexMapDatasetTest(absltest.TestCase):
325+
class MapWithIndexMapDatasetTest(parameterized.TestCase):
313326

314327
def setUp(self):
315328
super().setUp()
@@ -328,6 +341,17 @@ def test_getitem(self):
328341
self.assertEqual(ds[4], (4, 4))
329342
self.assertEqual(ds[5], (5, 5))
330343

344+
@parameterized.parameters(
345+
dict(indices=list(range(3, 6))),
346+
dict(indices=[3, 4, 5]),
347+
dict(indices=[4, 3, 5]),
348+
)
349+
def test_map_with_index_data_with_get_items(self, indices):
350+
ds = map_ds.MapWithIndexMapDataset(self.range_ds, AddIndexTransform())
351+
expected_data = [ds[i] for i in indices]
352+
actual_data = ds._getitems(indices)
353+
self.assertEqual(expected_data, actual_data)
354+
331355

332356
class MapWithIndexIterDatasetTest(absltest.TestCase):
333357

0 commit comments

Comments
 (0)