Skip to content

Commit f10c121

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 833971814
1 parent 20b4532 commit f10c121

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

grain/_src/python/dataset/transformations/zip.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def __getitem__(self, index):
4444
return self.slice(index)
4545
return tuple(p[index] for p in self._parents)
4646

47+
def _getitems(self, indices: Sequence[int]):
48+
# p._getitems(indices) returns a list of elements of the requested indices.
49+
# We get a list of lists that we need to zip.
50+
parent_elements = [
51+
p._getitems(indices) for p in self.parents # pylint: disable=protected-access
52+
]
53+
return list(zip(*parent_elements))
54+
4755
def __str__(self) -> str:
4856
return f"ZipMapDataset(parents={self._parents}"
4957

grain/_src/python/dataset/transformations/zip_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ def test_getitem(self, ds_idx_list):
5555
for i in range(20):
5656
self.assertEqual(ds[i], tuple(i + ds_idx for ds_idx in ds_idx_list))
5757

58+
@parameterized.parameters(
59+
{"ds_idx_list": x}
60+
for x in list(itertools.combinations(range(3), 3))
61+
+ list(itertools.combinations(range(3), 2))
62+
+ list(itertools.combinations(range(3), 1))
63+
)
64+
def test_getitems(self, ds_idx_list):
65+
ds = zip_ds.ZipMapDataset(parents=[self.ds_list[i] for i in ds_idx_list])
66+
indices = [0, 5, 19]
67+
expected_elements = [ds[i] for i in indices]
68+
actual_elements = ds._getitems(indices)
69+
self.assertEqual(expected_elements, actual_elements)
70+
5871

5972
class ZipIterDatasetTest(parameterized.TestCase):
6073

0 commit comments

Comments
 (0)