@@ -19,15 +19,16 @@
 from tqdm.auto import tqdm
 from transformers.utils import strtobool
 
-from swift.utils import (get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, safe_ddp_context,
-                         transform_jsonl_to_df)
+from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
 from swift.utils.torch_utils import _find_local_mac
 from .media import MediaCache, MediaTag
 from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
                          ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
                          TextGenerationPreprocessor, preprocess_sharegpt)
 from .utils import download_dataset
 
+dataset_enable_cache = strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))
+
 
 def _update_fingerprint_mac(*args, **kwargs):
     mac = _find_local_mac().replace(':', '')
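[Note] Every `load_from_cache_file=False` below is replaced by this module-level flag, so reuse of the HF `datasets` preprocessing cache becomes opt-in via the DATASET_ENABLE_CACHE environment variable. A minimal standalone sketch of the same gating pattern (the toy dataset and transform are illustrative, not from this file):

import os

from datasets import Dataset

# Mirrors the gate above: off by default, enabled by DATASET_ENABLE_CACHE.
# The real code parses the value with transformers.utils.strtobool; this
# inline check is an illustrative stand-in.
dataset_enable_cache = os.environ.get('DATASET_ENABLE_CACHE', 'False').lower() in ('1', 'true', 'yes')

ds = Dataset.from_dict({'text': ['hello', 'world']})

def upper_row(row):
    # Toy transform standing in for the preprocess/map_row functions below.
    return {'text': row['text'].upper()}

# With load_from_cache_file=False the map is recomputed on every run; with
# True, `datasets` may reuse a previously written Arrow cache file for an
# identical transform on an identical on-disk dataset.
ds = ds.map(upper_row, load_from_cache_file=dataset_enable_cache)
print(ds['text'])  # ['HELLO', 'WORLD'] either way; only caching differs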
@@ -378,7 +379,7 @@ def _post_preprocess(
         train_sample = dataset_sample - val_sample
         assert isinstance(val_sample, int)
         train_dataset, val_dataset = train_dataset.train_test_split(
-            test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=False).values()
+            test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=dataset_enable_cache).values()
 
     assert train_sample > 0
     train_dataset = sample_dataset(train_dataset, train_sample, random_state)
@@ -445,7 +446,8 @@ def preprocess_row(row):
             return {'image': [], 'conversations': []}
         return {'image': [image]}
 
-    dataset = dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['conversations'])
+    dataset = dataset.map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['conversations'])
     return ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
             dataset)
@@ -490,7 +492,7 @@ def preprocess_row(row):
         else:
             return {'images': []}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['images'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])
 
 
 def get_mantis_dataset(dataset_id: str,
@@ -575,7 +577,7 @@ def preprocess_image(example):
             example['images'] = []
         return example
 
-    dataset = dataset.map(preprocess_image, load_from_cache_file=False).filter(lambda row: row['images'])
+    dataset = dataset.map(preprocess_image, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])
     return ConversationsPreprocessor(
         user_role='user',
         assistant_role='assistant',
@@ -666,7 +668,7 @@ def preprocess(row):
             'query': np.random.choice(caption_prompt),
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False)
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -717,11 +719,9 @@ def _preprocess_aishell1_dataset(dataset: HfDataset) -> HfDataset:
 
 
 def _preprocess_video_chatgpt(dataset: HfDataset) -> HfDataset:
-    from datasets.download.download_manager import DownloadManager
     url = 'https://modelscope.cn/datasets/huangjintao/VideoChatGPT/resolve/master/videos.zip'
-    with safe_ddp_context():
-        local_dir = DownloadManager().download_and_extract(url)
-        local_dir = os.path.join(str(local_dir), 'Test_Videos')
+    local_dir = MediaCache.download(url, 'video_chatgpt')
+    local_dir = os.path.join(local_dir, 'Test_Videos')
     # only `.mp4`
     mp4_set = [file[:-4] for file in os.listdir(local_dir) if file.endswith('mp4')]
     query = []
@@ -794,7 +794,7 @@ def map_row(row):
         return response
 
     dataset = AlpacaPreprocessor()(dataset)
-    return dataset.map(map_row, load_from_cache_file=False)
+    return dataset.map(map_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -821,7 +821,7 @@ def map_row(row):
         title = match.group(1)
         return {'response': title}
 
-    return dataset.map(map_row, load_from_cache_file=False).filter(lambda row: row['response'])
+    return dataset.map(map_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['response'])
 
 
 register_dataset(
@@ -1002,7 +1002,8 @@ def reorganize_row(row):
             'history': history,
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'] is not None)
+    return dataset.map(
+        reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] is not None)
 
 
 register_dataset(
@@ -1067,7 +1068,7 @@ def row_can_be_parsed(row):
         return False
 
     return dataset.filter(row_can_be_parsed).map(
-        reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'])
+        reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])
 
 
 register_dataset(
@@ -1137,7 +1138,8 @@ def preprocess_image(example):
         return example
 
     dataset = dataset.map(
-        preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
+        preprocess_image,
+        load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
     processer = ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
     return processer(dataset)
@@ -1182,8 +1184,8 @@ def preprocess(row):
             return {'response': '', 'image': None}
 
     return dataset.map(
-        preprocess,
-        load_from_cache_file=False).filter(lambda row: row.get('response')).rename_columns({'image': 'images'})
+        preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')).rename_columns(
+            {'image': 'images'})
 
 
 def preprocess_refcoco_unofficial_caption(dataset):
@@ -1209,7 +1211,7 @@ def preprocess(row):
             res['response'] = ''
         return res
 
-    return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))
 
 
 register_dataset(
@@ -1254,7 +1256,7 @@ def preprocess(row):
             res['response'] = ''
         return res
 
-    return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))
 
 
 register_dataset(
@@ -1323,7 +1325,8 @@ def preprocess_image(example):
         return example
 
     dataset = dataset.map(
-        preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
+        preprocess_image,
+        load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
     processer = ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
     return processer(dataset)
@@ -1386,7 +1389,7 @@ def preprocess(row):
         else:
             return {'image': ''}
 
-    dataset = dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row['image'])
+    dataset = dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image'])
     return ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
             dataset)
@@ -1412,7 +1415,7 @@ def reorganize_row(row):
             'rejected_response': row['answer_en'],
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 def process_ultrafeedback_kto(dataset: HfDataset):
@@ -1424,7 +1427,7 @@ def reorganize_row(row):
             'label': row['label'],
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -1466,7 +1469,8 @@ def preprocess_row(row):
             'response': output,
         }
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'] and row['response'])
+    return dataset.map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] and row['response'])
 
 
 register_dataset(
@@ -1495,7 +1499,7 @@ def preprocess_row(row):
             'response': response,
         }
 
-    return dataset.map(preprocess_row, load_from_cache_file=False)
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -1537,7 +1541,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1560,7 +1564,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1584,7 +1588,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1606,7 +1610,8 @@ def preprocess_row(row):
         return {'query': query, 'response': f'{solution}\nSo the final answer is:{response}'}
 
     return dataset.map(
-        preprocess_row, load_from_cache_file=False).filter(lambda row: row['image']).rename_columns({'image': 'images'})
+        preprocess_row,
+        load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image']).rename_columns({'image': 'images'})
 
 
 register_dataset(
@@ -1660,7 +1665,7 @@ def preprocess_row(row):
 
         return {'images': images, 'response': response, 'objects': json.dumps(objects or [], ensure_ascii=False)}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['objects'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['objects'])
 
 
 register_dataset(
@@ -1687,7 +1692,7 @@ def preprocess_row(row):
         else:
             return {'query': '', 'response': '', 'images': ''}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])
 
 
 register_dataset(
@@ -1720,7 +1725,7 @@ def preprocess_row(row):
         return {'messages': rounds}
 
     dataset = dataset.map(
-        preprocess_row, load_from_cache_file=False).map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).map(
             ConversationsPreprocessor(
                 user_role='user',
                 assistant_role='assistant',
@@ -1730,7 +1735,7 @@ def preprocess_row(row):
                 media_key='images',
                 media_type='image',
             ).preprocess,
-            load_from_cache_file=False)
+            load_from_cache_file=dataset_enable_cache)
     return dataset
 
 
@@ -1787,8 +1792,8 @@ def preprocess(row):
         }
 
     return dataset.map(
-        preprocess,
-        load_from_cache_file=False).filter(lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)
+        preprocess, load_from_cache_file=dataset_enable_cache).filter(
+            lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)
 
 
 register_dataset(
@@ -1814,7 +1819,7 @@ def preprocess(row):
             'response': response,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False)
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -2116,7 +2121,7 @@ def reorganize_row(row):
             'response': convs[-1]['value']
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
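[Note] Because `dataset_enable_cache` is computed once at import time (see the first hunk), DATASET_ENABLE_CACHE must be set in the environment before the dataset module is imported; changing it afterwards has no effect on the flag. A minimal sketch (the `swift.llm` import as the trigger is assumed for illustration):

import os

# Set before importing: the flag is evaluated at module import time via
# strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False')).
os.environ['DATASET_ENABLE_CACHE'] = 'True'

import swift.llm  # importing the package is assumed to load the dataset module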