Skip to content

Commit 09255ce

Browse files
authored
[dataset] fix self-cognition & load_from_cache_file (#4426)
1 parent ab41c74 commit 09255ce

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

swift/llm/argument/base_args/data_args.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class DataArguments:
2222
streaming (bool): Flag to enable streaming of datasets. Default is False.
2323
download_mode (Literal): Mode for downloading datasets. Default is 'reuse_dataset_if_exists'.
2424
columns: Used for manual column mapping of datasets.
25-
model_name (List[str]): List containing Chinese and English names of the model. Default is [None, None].
26-
model_author (List[str]): List containing Chinese and English names of the model author.
27-
Default is [None, None].
25+
model_name (List[str]): List containing Chinese and English names of the model. Default is None.
26+
model_author (List[str]): List containing Chinese and English names of the model author. Default is None.
2827
custom_dataset_info (Optional[str]): Path to custom dataset_info.json file. Default is None.
2928
"""
3029
# dataset_id or dataset_dir or dataset_path
@@ -49,9 +48,8 @@ class DataArguments:
4948
strict: bool = False
5049
remove_unused_columns: bool = True
5150
# Chinese name and English name
52-
model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
53-
model_author: List[str] = field(
54-
default_factory=lambda: [None, None], metadata={'help': "e.g. ['魔搭', 'ModelScope']"})
51+
model_name: Optional[List[str]] = field(default=None, metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
52+
model_author: Optional[List[str]] = field(default=None, metadata={'help': "e.g. ['魔搭', 'ModelScope']"})
5553

5654
custom_dataset_info: List[str] = field(default_factory=list) # .json
5755

swift/llm/dataset/dataset/llm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,14 +825,18 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
825825

826826

827827
class SelfCognitionPreprocessor(ResponsePreprocessor):
828-
name: Optional[Tuple[str, str]] = None
829-
author: Optional[Tuple[str, str]] = None
830828

831829
def __init__(self, *args, query_suffix: str = '', response_prefix: str = '', **kwargs):
832830
self.query_suffix = query_suffix
833831
self.response_prefix = response_prefix
832+
self.name: Optional[Tuple[str, str]] = None
833+
self.author: Optional[Tuple[str, str]] = None
834834
super().__init__(*args, **kwargs)
835835

836+
def set_name_author(self, name, author):
837+
self.name = name
838+
self.author = author
839+
836840
def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
837841
for key in ['name', 'author']:
838842
val = getattr(self, key)
@@ -863,4 +867,5 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
863867
SubsetDataset(
864868
'empty_think', preprocess_func=SelfCognitionPreprocessor(response_prefix='<think>\n\n</think>\n\n')),
865869
],
870+
dataset_name='self-cognition',
866871
tags=['chat', 'self-cognition', '🔥']))

swift/llm/dataset/loader.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,18 +422,30 @@ def load(
422422

423423

424424
def init_self_cognition_preprocessor(
425+
dataset_meta: Optional[DatasetMeta],
425426
model_name: Union[Tuple[str, str], List[str], None] = None,
426427
model_author: Union[Tuple[str, str], List[str], None] = None,
427428
) -> None:
428-
from .dataset.llm import SelfCognitionPreprocessor
429+
if dataset_meta is None or model_name is None and model_author is None:
430+
return
431+
kwargs = {}
429432
# zh, en
430-
for key in ['model_name', 'model_author']:
431-
val = locals()[key]
433+
for key in ['name', 'author']:
434+
val = locals()[f'model_{key}']
432435
if isinstance(val, str):
433436
val = [val]
434437
if val is not None and val[0] is not None and (len(val) == 1 or val[1] is None):
435438
val = (val[0], val[0])
436-
setattr(SelfCognitionPreprocessor, key[len('model_'):], val)
439+
kwargs[key] = val
440+
441+
from .dataset.llm import SelfCognitionPreprocessor
442+
preprocess_funcs = [dataset_meta.preprocess_func]
443+
preprocess_funcs += [subset.preprocess_func for subset in dataset_meta.subsets if isinstance(subset, SubsetDataset)]
444+
for preprocess_func in preprocess_funcs:
445+
if isinstance(preprocess_func, SelfCognitionPreprocessor):
446+
preprocess_func.set_name_author(**kwargs)
447+
logger.info_once(f"SelfCognitionPreprocessor has been successfully configured with name: {kwargs['name']}, "
448+
f"author: {kwargs['author']}.")
437449

438450

439451
def load_dataset(
@@ -479,7 +491,7 @@ def load_dataset(
479491
Returns:
480492
The train dataset and val dataset
481493
"""
482-
init_self_cognition_preprocessor(model_name, model_author)
494+
init_self_cognition_preprocessor(DATASET_MAPPING.get('self-cognition'), model_name, model_author)
483495
if isinstance(datasets, str):
484496
datasets = [datasets]
485497
if not isinstance(seed, np.random.RandomState):

0 commit comments

Comments
 (0)