@@ -317,7 +317,7 @@ def _single_map(d: Dict[str, Any], map_func: MapFunc) -> Optional[Dict[str, Any]
317317
318318def _map_mp_single (shard : HfDataset , map_func : MapFunc , queue : Queue , rank : int ):
319319 batch_size = 64
320- pre_i = 0
320+ pre_i = - 1
321321 result = []
322322 for i , d in enumerate (shard ):
323323 output = map_func (d )
@@ -336,7 +336,7 @@ def _map_mp_i(dataset: HfDataset, map_func: MapFunc, num_proc: int) -> Iterator[
336336 os .environ = pre_environ
337337 queue = manager .Queue ()
338338 async_results = []
339- shard_list = [dataset .shard (num_proc , i ) for i in range (num_proc )]
339+ shard_list = [dataset .shard (num_proc , i , contiguous = True ) for i in range (num_proc )]
340340 for i in range (num_proc ):
341341 async_results .append (pool .apply_async (_map_mp_single , args = (shard_list [i ], map_func , queue , i )))
342342 while True :
@@ -350,11 +350,12 @@ def _map_mp_i(dataset: HfDataset, map_func: MapFunc, num_proc: int) -> Iterator[
350350def _map_mp (dataset : HfDataset , map_func : MapFunc , num_proc : int ) -> List [Dict [str , Any ]]:
351351 # Solving the unordered problem
352352 num_proc = min (num_proc , len (dataset ))
353- data_list = [[]] * num_proc
353+ data_list = [[] for _ in range ( num_proc )]
354354 prog_bar = tqdm (total = len (dataset ), desc = f'Map (num_proc={ num_proc } )' , dynamic_ncols = True )
355355 for d in _map_mp_i (dataset , map_func , num_proc ):
356356 data_list [d [0 ]] += d [1 ]
357357 prog_bar .update (d [2 ])
358+ prog_bar .close ()
358359 res = []
359360 for data in data_list :
360361 res += data
0 commit comments