@@ -161,17 +161,34 @@ def _transform_name(self):
161161 def __str__ (self ) -> str :
162162 return f"RandomMapMapDataset(transform={ self ._transform_name } )"
163163
164+ def _random_map_element (self , element : Any , index : int ) -> T :
165+ if element is None :
166+ return None
167+ rng = self ._rng_pool .acquire_rng (index )
168+ element = self ._map_fn (element , rng )
169+ self ._rng_pool .release_rng (rng )
170+ return element
171+
164172 def __getitem__ (self , index ):
165173 if isinstance (index , slice ):
166174 return self .slice (index )
167175 element = self ._parent [index ]
168176 with self ._stats .record_self_time ():
169- if element is None :
170- return None
171- rng = self ._rng_pool .acquire_rng (index )
172- element = self ._map_fn (element , rng )
173- self ._rng_pool .release_rng (rng )
174- return self ._stats .record_output_spec (element )
177+ mapped_element = self ._random_map_element (element , index )
178+ return (
179+ self ._stats .record_output_spec (mapped_element )
180+ if mapped_element is not None
181+ else None
182+ )
183+
184+ def _getitems (self , indices : Sequence [int ]):
185+ elements = self ._parent ._getitems (indices ) # pylint: disable=protected-access
186+ with self ._stats .record_self_time (num_elements = len (indices )):
187+ processed_elements = [
188+ self ._random_map_element (element , index )
189+ for element , index in zip (elements , indices )
190+ ]
191+ return self ._stats .record_output_spec_for_batch (processed_elements )
175192
176193
177194class MapWithIndexMapDataset (dataset .MapDataset [T ]):
@@ -201,14 +218,31 @@ def __len__(self) -> int:
201218 def __str__ (self ) -> str :
202219 return f"MapWithIndexMapDataset(transform={ self ._transform_name } )"
203220
221+ def _map_with_index_fn (self , index : int , element : Any ) -> T :
222+ if element is None :
223+ return None
224+ return self ._map_fn (index , element )
225+
204226 def __getitem__ (self , index ):
205227 if isinstance (index , slice ):
206228 return self .slice (index )
207229 element = self ._parent [index ]
208230 with self ._stats .record_self_time ():
209- if element is None :
210- return None
211- return self ._stats .record_output_spec (self ._map_fn (index , element ))
231+ mapped_element = self ._map_with_index_fn (index , element )
232+ return (
233+ self ._stats .record_output_spec (mapped_element )
234+ if mapped_element is not None
235+ else None
236+ )
237+
238+ def _getitems (self , indices : Sequence [int ]):
239+ elements = self ._parent ._getitems (indices ) # pylint: disable=protected-access
240+ with self ._stats .record_self_time (num_elements = len (indices )):
241+ processed_elements = [
242+ self ._map_with_index_fn (index , element )
243+ for index , element in zip (indices , elements )
244+ ]
245+ return self ._stats .record_output_spec_for_batch (processed_elements )
212246
213247
214248class _MapDatasetIterator (dataset .DatasetIterator [T ]):
0 commit comments