1717from typing import Callable , Tuple
1818
1919from absl .testing import absltest
20+ from absl .testing import parameterized
2021from grain ._src .python .dataset import base
2122from grain ._src .python .dataset import dataset
2223from grain ._src .python .dataset .transformations import mix
@@ -154,7 +155,7 @@ def _subset_and_shuffle_dataset(index):
154155 self .assertEqual (expected_dataset , unrolled_dataset )
155156
156157
157- class MixedMapDatasetTest (absltest .TestCase ):
158+ class MixedMapDatasetTest (parameterized .TestCase ):
158159
159160 def setUp (self ):
160161 super ().setUp ()
@@ -169,7 +170,7 @@ def test_len(self):
169170 # Equal proportions.
170171 ds = mix .MixedMapDataset ([ds1 , ds2 , ds3 ])
171172 self .assertLen (ds , 15 )
172- # Heigher weight for second dataset.
173+ # Higher weight for second dataset.
173174 ds = mix .MixedMapDataset ([ds1 , ds2 , ds3 ], proportions = [1 , 2 , 1 ])
174175 self .assertLen (ds , 5 + 10 + 5 )
175176
@@ -333,6 +334,99 @@ def _subset_and_shuffle_dataset(index):
333334
334335 self .assertEqual (list (ds ), expected_dataset )
335336
337+ @parameterized .named_parameters (
338+ dict (
339+ testcase_name = "equal_proportions" ,
340+ proportions = [1 , 1 ],
341+ expected_values = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
342+ ),
343+ dict (
344+ testcase_name = "float_proportions" ,
345+ proportions = [0.75 , 0.25 ],
346+ expected_values = [0 , 2 , 4 , 1 , 6 , 8 , 0 , 3 , 2 , 4 ],
347+ ),
348+ dict (
349+ testcase_name = "integer_proportions" ,
350+ proportions = [1 , 2 ],
351+ expected_values = [0 , 1 , 3 , 2 , 5 , 7 , 4 , 9 , 1 , 6 ],
352+ ),
353+ )
354+ def test_getitems_mixing (self , proportions , expected_values ):
355+ mixed_ds = mix .MixedMapDataset (
356+ parents = [self .even_ds , self .odd_ds ],
357+ proportions = proportions ,
358+ )
359+ # Request a batch of indices that results in interleaved elements
360+ indices = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]
361+ actual_values = mixed_ds ._getitems (indices )
362+ self .assertEqual (expected_values , actual_values )
363+
364+ def test_getitems_interleaved_map (self ):
365+ def _inteleaved_dataset (index ):
366+ if index > 9 :
367+ raise IndexError ("index our of range" )
368+ ds = index % 2
369+ ds_index = index // 2
370+ return (ds , ds_index )
371+
372+ interleaved_map = ExplicitSelectionMap (10 , _inteleaved_dataset )
373+
374+ ds = mix .MixedMapDataset (
375+ parents = [self .even_ds , self .odd_ds ], selection_map = interleaved_map
376+ )
377+
378+ expected_dataset = list (range (10 ))
379+ self .assertEqual (ds ._getitems (list (range (10 ))), expected_dataset )
380+
381+ def test_getitems_sequential_map (self ):
382+ def _sequential_dataset (index ):
383+ if index > 9 :
384+ raise IndexError ("index our of range" )
385+ if index < 5 :
386+ ds = 0
387+ else :
388+ ds = 1
389+ ds_index = index % 5
390+ return (ds , ds_index )
391+
392+ sequential_map = ExplicitSelectionMap (10 , _sequential_dataset )
393+
394+ ds = mix .MixedMapDataset (
395+ parents = [self .even_ds , self .odd_ds ], selection_map = sequential_map
396+ )
397+
398+ expected_dataset = list (range (0 , 10 , 2 )) + list (range (1 , 10 , 2 ))
399+ self .assertEqual (ds ._getitems (list (range (10 ))), expected_dataset )
400+
401+ def test_getitems_subset_and_shuffle_map (self ):
402+ first_epoch = [0 , 1 , 2 , 3 , 4 ]
403+ second_epoch = [1 , 0 , 3 , 2 , 4 ]
404+
405+ expected_dataset = first_epoch + second_epoch
406+
407+ def _subset_and_shuffle_dataset (index ):
408+ if index > 9 :
409+ raise IndexError ("index our of range" )
410+ if index < 5 :
411+ ds = index % 2
412+ ds_index = index // 2
413+ else :
414+ mapped_index = second_epoch [index - 5 ]
415+ ds = mapped_index % 2
416+ ds_index = mapped_index // 2
417+ return (ds , ds_index )
418+
419+ subset_and_shuffle_map = ExplicitSelectionMap (
420+ 10 , _subset_and_shuffle_dataset
421+ )
422+
423+ ds = mix .MixedMapDataset (
424+ parents = [self .even_ds , self .odd_ds ],
425+ selection_map = subset_and_shuffle_map ,
426+ )
427+
428+ self .assertEqual (ds ._getitems (list (range (10 ))), expected_dataset )
429+
336430
337431class MixedIterDatasetTest (absltest .TestCase ):
338432
@@ -627,6 +721,17 @@ def test_cannot_concatenate_infinite_datasets(self):
627721 ):
628722 _ = mix ._ConcatSelectionMap ([zeros , ones ])
629723
724+ def test_getitems_concatenate_finite_datasets (self ):
725+ evens = dataset .MapDataset .range (0 , 10 , 2 )
726+ odds = dataset .MapDataset .range (1 , 10 , 2 )
727+ ds = mix .ConcatenateMapDataset ([evens , odds ])
728+ self .assertLen (evens , 5 )
729+ self .assertLen (odds , 5 )
730+ self .assertLen (ds , 10 )
731+
732+ expected_values = [0 , 2 , 4 , 6 , 8 , 1 , 3 , 5 , 7 , 9 ]
733+ self .assertEqual (ds ._getitems (list (range (10 ))), expected_values )
734+
630735
631736if __name__ == "__main__" :
632737 absltest .main ()
0 commit comments