Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions grain/_src/python/dataset/transformations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,34 @@ def _transform_name(self):
def __str__(self) -> str:
  """Returns a short description naming this dataset and its transform."""
  description = f"RandomMapMapDataset(transform={self._transform_name})"
  return description

def _random_map_element(self, element: Any, index: int) -> T:
  """Applies the random map function to a single element.

  Args:
    element: The parent element to transform, or None.
    index: Index of the element; forwarded to the RNG pool when a generator
      is acquired.

  Returns:
    None when `element` is None, otherwise whatever
    `self._map_fn(element, rng)` returns.
  """
  if element is None:
    # Nothing to transform, so no generator is borrowed from the pool.
    return None
  rng = self._rng_pool.acquire_rng(index)
  try:
    return self._map_fn(element, rng)
  finally:
    # Hand the generator back even when the map function raises; otherwise
    # every failing call would take one generator out of the pool for good.
    self._rng_pool.release_rng(rng)

def __getitem__(self, index):
  """Returns the randomly mapped element at `index`, or a sliced dataset.

  Args:
    index: An integer position, or a slice.

  Returns:
    `self.slice(index)` when `index` is a slice. Otherwise the mapped parent
    element after its output spec has been recorded, or None when the parent
    element is None or the map function produced None.
  """
  if isinstance(index, slice):
    return self.slice(index)
  element = self._parent[index]
  with self._stats.record_self_time():
    # The per-element work (None check, RNG handling, map call) lives in the
    # shared helper so that batched lookups can reuse it.
    mapped_element = self._random_map_element(element, index)
  # Output specs are only recorded for elements that exist after mapping.
  if mapped_element is None:
    return None
  return self._stats.record_output_spec(mapped_element)

def _getitems(self, indices: Sequence[int]):
  """Fetches and randomly maps the elements at several indices at once.

  Args:
    indices: Positions to read from the parent dataset.

  Returns:
    The value of `self._stats.record_output_spec_for_batch` applied to the
    list of mapped elements, which is in the same order as `indices`.
  """
  parent_elements = self._parent._getitems(indices)  # pylint: disable=protected-access
  with self._stats.record_self_time(num_elements=len(indices)):
    # Each parent element is paired with the index it was requested at.
    mapped = list(map(self._random_map_element, parent_elements, indices))
  return self._stats.record_output_spec_for_batch(mapped)


class MapWithIndexMapDataset(dataset.MapDataset[T]):
Expand Down Expand Up @@ -201,14 +218,31 @@ def __len__(self) -> int:
def __str__(self) -> str:
  """Returns a short description naming this dataset and its transform."""
  description = f"MapWithIndexMapDataset(transform={self._transform_name})"
  return description

def _map_with_index_fn(self, index: int, element: Any) -> T:
  """Applies the index-aware map function to a single element.

  Args:
    index: Index of the element; passed to the map function first.
    element: The parent element to transform, or None.

  Returns:
    None when `element` is None, otherwise `self._map_fn(index, element)`.
  """
  return None if element is None else self._map_fn(index, element)

def __getitem__(self, index):
  """Returns the index-mapped element at `index`, or a sliced dataset.

  Args:
    index: An integer position, or a slice.

  Returns:
    `self.slice(index)` when `index` is a slice. Otherwise the mapped parent
    element after its output spec has been recorded, or None when the parent
    element is None or the map function produced None.
  """
  if isinstance(index, slice):
    return self.slice(index)
  element = self._parent[index]
  with self._stats.record_self_time():
    # The None check and the map call live in the shared helper so that
    # batched lookups can reuse them.
    mapped_element = self._map_with_index_fn(index, element)
  # Output specs are only recorded for elements that exist after mapping.
  if mapped_element is None:
    return None
  return self._stats.record_output_spec(mapped_element)

def _getitems(self, indices: Sequence[int]):
  """Fetches and index-maps the elements at several indices at once.

  Args:
    indices: Positions to read from the parent dataset.

  Returns:
    The value of `self._stats.record_output_spec_for_batch` applied to the
    list of mapped elements, which is in the same order as `indices`.
  """
  parent_elements = self._parent._getitems(indices)  # pylint: disable=protected-access
  with self._stats.record_self_time(num_elements=len(indices)):
    # Each requested index is paired with the parent element read for it.
    mapped = list(map(self._map_with_index_fn, indices, parent_elements))
  return self._stats.record_output_spec_for_batch(mapped)


class _MapDatasetIterator(dataset.DatasetIterator[T]):
Expand Down
26 changes: 25 additions & 1 deletion grain/_src/python/dataset/transformations/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ def test_random_map_raises_with_no_seed(self, transform):
with self.assertRaises(ValueError):
map_ds.RandomMapMapDataset(self.range_ds, transform)

@parameterized.parameters(
    {"indices": list(range(10))},
    {"indices": [0, 4, 8]},
    {"indices": [3, 1, 7, 2]},
)
def test_random_map_data_with_get_items(self, indices):
  """Checks that batched `_getitems` agrees with per-index lookups."""
  dataset_under_test = map_ds.RandomMapMapDataset(
      self.range_ds, RandomMapWithDeterminismTransform(), seed=42
  )
  # Per-index results are computed first and serve as the reference.
  one_by_one = [dataset_under_test[i] for i in indices]
  batched = dataset_under_test._getitems(indices)  # pylint: disable=protected-access
  np.testing.assert_equal(one_by_one, batched)

@parameterized.parameters(0, 1, 42)
def test_random_map_checkpointing(self, random_map_seed):
ds = map_ds.RandomMapMapDataset(
Expand Down Expand Up @@ -309,7 +322,7 @@ def test_random_map_checkpointing(self, random_map_seed):
assert_equal_output_after_checkpoint(ds)


class MapWithIndexMapDatasetTest(absltest.TestCase):
class MapWithIndexMapDatasetTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand All @@ -328,6 +341,17 @@ def test_getitem(self):
self.assertEqual(ds[4], (4, 4))
self.assertEqual(ds[5], (5, 5))

@parameterized.parameters(
    dict(indices=list(range(3, 6))),
    # Non-contiguous indices; `[3, 4, 5]` would only repeat the range case
    # above, since `list(range(3, 6))` is that same list.
    dict(indices=[3, 5]),
    dict(indices=[4, 3, 5]),
)
def test_map_with_index_data_with_get_items(self, indices):
  """Checks that batched `_getitems` agrees with per-index lookups."""
  ds = map_ds.MapWithIndexMapDataset(self.range_ds, AddIndexTransform())
  expected_data = [ds[i] for i in indices]
  actual_data = ds._getitems(indices)  # pylint: disable=protected-access
  self.assertEqual(expected_data, actual_data)


class MapWithIndexIterDatasetTest(absltest.TestCase):

Expand Down
Loading