@@ -314,7 +314,6 @@ def __call__(
         dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state)
 
         map_kwargs = {'batched': True, 'batch_size': batch_size}
-        cache_file_name = None
         if isinstance(dataset, HfDataset):
             if not load_from_cache_file and is_dist() and not is_master():
                 load_from_cache_file = True
@@ -326,29 +325,28 @@ def __call__(
         dataset = RowPreprocessor.get_features_dataset(dataset)
         if 'solution' in dataset.features:
             with safe_ddp_context(None, True):
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
-                dataset = dataset.map(
-                    lambda x: {'__#solution': x['solution']}, **map_kwargs, cache_file_name=cache_file_name)
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
+                dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
+                map_kwargs.pop('cache_file_name', None)
         dataset = self._rename_columns(dataset)
         dataset = self.prepare_dataset(dataset)
         dataset = self._cast_pil_image(dataset)
 
         ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
         with self._patch_arrow_writer(), safe_ddp_context(None, True):
             try:
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
                 dataset_mapped = dataset.map(
                     self.batched_preprocess,
                     fn_kwargs={
                         'strict': strict,
                         'ignore_max_length_error': ignore_max_length_error
                     },
                     remove_columns=list(dataset.features.keys()),
-                    cache_file_name=cache_file_name,
                     **map_kwargs)
             except NotImplementedError:
                 pass
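
For reference, a minimal standalone sketch of the caching pattern this diff adds: when the dataset is an in-memory HfDataset with no cache files, a fingerprint-derived cache_file_name is put into the shared map_kwargs before .map() and popped afterwards so a later .map() call does not reuse the same path. The cache directory below is a hypothetical stand-in for get_cache_dir(), which is not defined in this snippet.

import os
from datasets import Dataset

# Hypothetical stand-in for get_cache_dir() from the diff above.
cache_dir = os.path.join(os.path.expanduser('~/.cache/map_cache_demo'), 'datasets', 'map_cache')
os.makedirs(cache_dir, exist_ok=True)

ds = Dataset.from_dict({'solution': ['42', '7']})  # in-memory dataset, so ds.cache_files == []
map_kwargs = {'batched': True, 'batch_size': 16}

if isinstance(ds, Dataset) and not ds.cache_files:
    # Pin the map cache to a fingerprint-derived Arrow file so re-runs can reuse it.
    map_kwargs['cache_file_name'] = os.path.join(cache_dir, f'{ds._fingerprint}.arrow')

ds = ds.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
# Drop the path so a later .map() with the same map_kwargs computes its own cache file.
map_kwargs.pop('cache_file_name', None)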