@@ -317,3 +317,35 @@ def refine():
317317 save_path = cfg ['save_path' ]
318318 for dataset in dataset_dict .values ():
319319 dataset .dump (save_path )
320+
def deduplicate():
    """Run the configured scorer pipelines over the loaded datasets.

    Reads the run configuration via ``api_init_config``.  ``cfg.yaml`` maps
    scorer names to their argument dicts; it may also be a path to a YAML
    file, in which case it is parsed first.  For each scorer, the matching
    processor is built, the dataset for its ``data_type`` is loaded once
    (and cached in ``dataset_dict``), and the processed result replaces the
    cached dataset so later processors see the filtered data.

    Side effects: loads datasets from ``cfg['data']`` and prints each
    processed dataset.  Returns ``None``.
    """
    from ..config import api_init_config
    from dataflow.data import DataFlowDSDict
    from dataflow.utils.registry import FORMATTER_REGISTRY
    # NOTE(review): removed unused local import `from dataflow.core import
    # ScoreRecord` — nothing in this function referenced it.

    cfg = api_init_config()
    dataset_dict = DataFlowDSDict()

    # cfg.yaml may be a path to a YAML file; parse it into a dict.
    if isinstance(cfg.yaml, str):
        with open(cfg.yaml, 'r') as f:
            cfg.yaml = yaml.safe_load(f)

    for scorer_name, args in cfg.yaml.items():
        # Top-level config overrides take precedence over per-scorer args.
        if "num_workers" in cfg:
            args["num_workers"] = cfg.num_workers
        if "model_cache_path" in cfg:
            args["model_cache_dir"] = cfg.model_cache_path

        processor = get_processor(scorer_name, args)

        # Load each data type's dataset lazily, exactly once; subsequent
        # processors of the same type reuse (and refine) the cached dataset.
        if processor.data_type not in dataset_dict.keys():
            formatter = FORMATTER_REGISTRY.get('TextFormatter')(
                cfg['data'],
                cfg['key'],
                cfg['sft_single_round'],
                cfg['sft_multi_round'],
                cfg['RLHF'],
            )
            datasets = formatter.load_dataset()
            dataset_dict[processor.data_type] = datasets
        else:
            datasets = dataset_dict[processor.data_type]

        processed_dataset = processor(datasets)
        dataset_dict[processor.data_type] = processed_dataset
        print(processed_dataset)
348+
349+
350+
351+
0 commit comments