File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed
grain/_src/python/dataset/transformations Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -44,6 +44,14 @@ def __getitem__(self, index):
4444 return self .slice (index )
4545 return tuple (p [index ] for p in self ._parents )
4646
47+ def _getitems (self , indices : Sequence [int ]):
48+ # p._getitems(indices) returns a list of elements of the requested indices.
49+ # We get a list of lists that we need to zip.
50+ parent_elements = [
51+ p ._getitems (indices ) for p in self .parents # pylint: disable=protected-access
52+ ]
53+ return list (zip (* parent_elements ))
54+
4755 def __str__ (self ) -> str :
4856 return f"ZipMapDataset(parents={ self ._parents } "
4957
Original file line number Diff line number Diff line change @@ -55,6 +55,19 @@ def test_getitem(self, ds_idx_list):
5555 for i in range (20 ):
5656 self .assertEqual (ds [i ], tuple (i + ds_idx for ds_idx in ds_idx_list ))
5757
58+ @parameterized .parameters (
59+ {"ds_idx_list" : x }
60+ for x in list (itertools .combinations (range (3 ), 3 ))
61+ + list (itertools .combinations (range (3 ), 2 ))
62+ + list (itertools .combinations (range (3 ), 1 ))
63+ )
64+ def test_getitems (self , ds_idx_list ):
65+ ds = zip_ds .ZipMapDataset (parents = [self .ds_list [i ] for i in ds_idx_list ])
66+ indices = [0 , 5 , 19 ]
67+ expected_elements = [ds [i ] for i in indices ]
68+ actual_elements = ds ._getitems (indices )
69+ self .assertEqual (expected_elements , actual_elements )
70+
5871
5972class ZipIterDatasetTest (parameterized .TestCase ):
6073
You can’t perform that action at this time.
0 commit comments