@@ -20,7 +20,8 @@
 
 from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
 from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
-                         PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor, TextGenerationPreprocessor)
+                         PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor, TextGenerationPreprocessor,
+                         preprocess_sharegpt)
 from .template import History
 from .utils import download_dataset
 
@@ -213,21 +214,26 @@ def register_local_dataset(
 
 
 def register_dataset_info(dataset_name: str, d_info: Dict[str, Any], **kwargs) -> None:
-    if 'dataset_path' in d_info:
-        base_dir = kwargs.pop('base_dir', None)
-        register_local_dataset(dataset_name, d_info.pop('dataset_path', None), base_dir, **d_info)
-        return
-
-    assert 'dataset_id' in d_info or 'hf_dataset_id' in d_info
     preprocess_func = None
     if 'columns' in d_info:
         preprocess_func = RenameColumnsPreprocessor(d_info['columns'])
         d_info.pop('columns')
+        d_info['preprocess_func'] = preprocess_func
     elif 'conversations' in d_info:
         preprocess_func = ConversationsPreprocessor(**d_info['conversations'])
         d_info.pop('conversations')
+        d_info['preprocess_func'] = preprocess_func
+
+    if 'dataset_path' in d_info:
+        base_dir = kwargs.pop('base_dir', None)
+        register_local_dataset(dataset_name, d_info.pop('dataset_path', None), base_dir, **d_info)
+        return
+
+    assert 'dataset_id' in d_info or 'hf_dataset_id' in d_info
+
     dataset_id = d_info.pop('dataset_id', None)
     subsets = d_info.pop('subsets', None)
+    preprocess_func = d_info.pop('preprocess_func', None)
     register_dataset(dataset_name, dataset_id, subsets, preprocess_func, get_dataset_from_repo, **d_info, exist_ok=True)
 
 
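For context, a minimal sketch of how a `dataset_info` entry flows through the reordered `register_dataset_info`. The entry contents below are hypothetical; the point of the reordering is that `columns`/`conversations` are converted into a preprocessor before the `dataset_path` branch, so local datasets now receive it too (previously the early `return` skipped that handling):

```python
# Hypothetical dataset_info entries, for illustration only.

# Hub dataset: 'columns' becomes a RenameColumnsPreprocessor, is stashed
# in d_info, and is popped back out before the register_dataset() call.
register_dataset_info('my-dataset', {
    'dataset_id': 'my-org/my-dataset',  # hypothetical id
    'columns': {'instruction': 'query', 'output': 'response'},
})

# Local dataset: the same 'columns' handling now runs first, so the
# preprocessor is forwarded to register_local_dataset() via **d_info.
register_dataset_info('my-local-dataset', {
    'dataset_path': 'train.jsonl',  # hypothetical path
    'columns': {'instruction': 'query', 'output': 'response'},
})
```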
@@ -809,30 +815,10 @@ def reorganize_row(row):
     get_dataset_from_repo,
     tags=['rlhf', 'dpo', 'pairwise'])
 
-
-def _preprocess_sharegpt(dataset: HfDataset) -> HfDataset:
-    query = []
-    response = []
-    history: List[History] = []
-    for d in tqdm(dataset):
-        if isinstance(d['conversation'], str):
-            try:
-                conversation = ast.literal_eval(d['conversation'])
-            except SyntaxError:
-                continue
-        query.append(conversation[-1]['human'])
-        response.append(conversation[-1]['assistant'])
-        h = []
-        for c in conversation[:-1]:
-            h.append([c['human'], c['assistant']])
-        history.append(h)
-    return HfDataset.from_dict({'query': query, 'response': response, 'history': history})
-
-
 register_dataset(
     DatasetName.sharegpt,
     'huangjintao/sharegpt', ['common-zh', 'computer-zh', 'unknow-zh', 'common-en', 'computer-en'],
-    _preprocess_sharegpt,
+    preprocess_sharegpt,
     get_dataset_from_repo,
     tags=['chat', 'general', 'multi-round'])
 
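Since the helper now lives in `.preprocess` as `preprocess_sharegpt`, the removed body above is the reference for the record shape this registration assumes; a sketch with made-up conversation content:

```python
# Record shape expected by preprocess_sharegpt, per the removed helper
# above. The 'conversation' field may arrive as a stringified Python
# literal, hence the ast.literal_eval call; content here is made up.
row = {
    'conversation': str([
        {'human': 'What is 2 + 2?', 'assistant': '4.'},
        {'human': 'And doubled?', 'assistant': '8.'},
    ])
}
# The last turn becomes query/response; earlier turns become history:
#   query    -> 'And doubled?'
#   response -> '8.'
#   history  -> [['What is 2 + 2?', '4.']]
```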
|