
Commit 8727cd3

Fix/0412 (#690)
Parent: 254122c

6 files changed: +63 −34 lines changed

docs/source/LLM/命令行参数.md

Lines changed: 2 additions & 1 deletion
@@ -157,6 +157,7 @@
 dpo parameters inherit from sft parameters; in addition, the following parameters are added:

 - `--ref_model_type`: Type of the reference model; the available `model_type` options can be found in `MODEL_MAPPING.keys()`.
+- `--ref_model_id_or_path`: Local cache path of the reference model, default `None`.
 - `--max_prompt_length`: Maximum prompt length; this parameter is passed to DPOTrainer so that the prompt does not exceed this length, default `1024`.
 - `--beta`: Regularization term for the DPO logits, default 0.1.
 - `--label_smoothing`: Whether to use DPO smoothing, default 0; generally set between 0 and 0.5.
@@ -240,7 +241,7 @@ eval parameters inherit from infer parameters; in addition, the following parameters are added:
 - `--eval_dataset`: Official datasets to evaluate on, default `['ceval', 'gsm8k', 'arc']`; the `mmlu` and `bbh` datasets are also supported. To evaluate only a custom dataset, set this parameter to `no`.
 - `--eval_limit`: Number of samples drawn from each sub-dataset of every evaluation set, default `None`, meaning full evaluation.
 - `--eval_few_shot`: Number of few-shot examples for each sub-dataset of every evaluation set, default `None`, meaning the dataset's default configuration is used.
-- `--custom_eval_config`: Evaluate with a custom dataset; must be a file path that exists locally. See [Custom evaluation set](./LLM评测文档#自定义评测集) for the file format.
+- `--custom_eval_config`: Evaluate with a custom dataset; must be a file path that exists locally. See [Custom evaluation set](./LLM评测文档.md#自定义评测集) for the file format.

 ## app-ui parameters

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ The following parameters take effect when `sft_type` is set to `ia3`.
 dpo parameters inherit from sft parameters, with the following added parameters:

 - `--ref_model_type`: Type of the reference model; available `model_type` options can be found in `MODEL_MAPPING.keys()`.
+- `--ref_model_id_or_path`: The local cache dir for the reference model, default `None`.
 - `--max_prompt_length`: Maximum prompt length; this parameter is passed to DPOTrainer so that the prompt does not exceed this length, default is `1024`.
 - `--beta`: Regularization term for DPO logits, default is 0.1.
 - `--label_smoothing`: Whether to use DPO smoothing, default is 0; generally set between 0 and 0.5.
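The new flag corresponds to the `ref_model_id_or_path` field added to `DPOArguments` in swift/llm/utils/argument.py below. A minimal sketch of setting the two reference-model parameters programmatically; it assumes `DPOArguments` is re-exported from `swift.llm`, and the `model_type` value and cache path are purely illustrative:

```python
# Hedged sketch: the field names come from this commit; the model_type
# value, the path, and the top-level import location are assumptions.
from swift.llm import DPOArguments

args = DPOArguments(
    model_type='qwen-7b-chat',                   # hypothetical base model
    ref_model_type='qwen-7b-chat',               # reference model for DPO
    ref_model_id_or_path='/cache/qwen-7b-chat',  # new: local cache path
)
```

Before this commit only `ref_model_type` existed, so the reference model's weights were always resolved from the default location for that type.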

swift/llm/dpo.py

Lines changed: 6 additions & 3 deletions
@@ -58,9 +58,12 @@ def llm_dpo(args: DPOArguments) -> str:
         model_id_or_path=args.model_id_or_path,
         **kwargs)
     if args.ref_model_type is not None:
-        ref_model, _ = get_model_tokenizer(args.ref_model_type,
-                                           args.torch_dtype, model_kwargs,
-                                           **kwargs)
+        ref_model, _ = get_model_tokenizer(
+            args.ref_model_type,
+            args.torch_dtype,
+            model_kwargs,
+            model_id_or_path=args.ref_model_id_or_path,
+            **kwargs)
     else:
         ref_model = None
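The functional change here is the explicit `model_id_or_path=args.ref_model_id_or_path` keyword: a local path can now override resolution by `model_type` when loading the reference model. A toy stand-in (not swift's real loader) for that fallback behavior:

```python
# Toy stand-in for the loader's path resolution, not swift's real code:
# an explicit model_id_or_path wins; None falls back to the default
# registered for the model_type, which was the only behavior before.
def resolve_path(model_type: str, model_id_or_path: str = None) -> str:
    default_paths = {'toy-model': 'hub://org/toy-model'}  # hypothetical registry
    return model_id_or_path or default_paths[model_type]

assert resolve_path('toy-model') == 'hub://org/toy-model'
assert resolve_path('toy-model', '/local/cache/toy') == '/local/cache/toy'
```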

swift/llm/eval.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,7 @@
 from modelscope import GenerationConfig

 from swift.utils import get_logger, get_main
-from . import (EvalArguments, inference, inference_vllm, merge_lora,
-               prepare_model_template)
+from . import EvalArguments, inference, merge_lora, prepare_model_template

 logger = get_logger()

@@ -39,6 +38,7 @@ def __init__(self, args: EvalArguments, model_name, config={}, **kwargs):

     def predict(self, prompt: str, **kwargs):
         if self.args.infer_backend == 'vllm':
+            from . import inference_vllm
             request_list = [{
                 'query': prompt,
                 'history': kwargs.get('history'),
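`inference_vllm` moves from a module-level import to a deferred import inside the only branch that needs it, so swift/llm/eval.py can be imported in environments where vllm is not installed. A generic, self-contained sketch of the pattern, with a standard-library module standing in for the heavy optional dependency:

```python
def render(fmt: str, payload: dict) -> str:
    # Deferred-import pattern: the import runs only when this branch
    # executes, so callers of the other branch never need the dependency
    # (json here stands in for `from . import inference_vllm`).
    if fmt == 'json':
        import json
        return json.dumps(payload)
    return str(payload)

print(render('json', {'a': 1}))   # {"a": 1}
print(render('plain', {'a': 1}))  # {'a': 1}
```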

swift/llm/utils/argument.py

Lines changed: 7 additions & 4 deletions
@@ -1,14 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import inspect
 import math
 import os
 from dataclasses import dataclass, field
-from typing import Dict, List, Literal, Optional, Set, Tuple, Union
+from typing import List, Literal, Optional, Set, Tuple, Union

 import json
 import numpy as np
 import torch
-import torch.distributed as dist
 import transformers
 from datasets import Dataset as HfDataset
 from datasets import concatenate_datasets
@@ -29,7 +27,7 @@
                       register_dataset)
 from .model import (MODEL_MAPPING, dtype_mapping, get_additional_saved_files,
                     get_default_lora_target_modules, get_default_template_type)
-from .template import TEMPLATE_MAPPING, TemplateType
+from .template import TEMPLATE_MAPPING
 from .utils import is_vllm_available

 logger = get_logger()
@@ -845,6 +843,8 @@ class DPOArguments(SftArguments):
         default=None,
         metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})

+    ref_model_id_or_path: Optional[str] = None
+
     max_prompt_length: int = 1024
     beta: float = 0.1
     label_smoothing: float = 0.0
@@ -1169,6 +1169,9 @@ def load_from_ckpt_dir(args: InferArguments) -> None:
             continue
         setattr(args, key, sft_args.get(key))

+    if args.model_id_or_path is None:
+        args.model_id_or_path = sft_args.get('model_id_or_path')
+

 def check_flash_attn(args: Union[SftArguments, InferArguments]) -> None:
     model_info = MODEL_MAPPING[args.model_type]
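Besides the import cleanup (`inspect`, `Dict`, `torch.distributed`, and `TemplateType` were unused), the substantive changes are the new `ref_model_id_or_path` field on `DPOArguments` and a fallback in `load_from_ckpt_dir`: if the user did not pass `--model_id_or_path`, the value recorded in the checkpoint's saved sft arguments is reused. A self-contained sketch of that fallback, with a toy dict standing in for the real saved `sft_args`:

```python
# Toy reproduction of the fallback added to load_from_ckpt_dir; the dict
# contents and path are made up, but the assignment mirrors the diff above.
sft_args = {'model_id_or_path': '/ckpt/base-model'}  # hypothetical saved args

class Args:
    """Minimal stand-in for InferArguments."""
    model_id_or_path = None

args = Args()
if args.model_id_or_path is None:
    args.model_id_or_path = sft_args.get('model_id_or_path')

assert args.model_id_or_path == '/ckpt/base-model'
```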

swift/llm/utils/dataset.py

Lines changed: 45 additions & 24 deletions
@@ -549,7 +549,8 @@ def map_row(row):
         if response and response.startswith('Answer:'):
             response = response[len('Answer:') + 1:].strip()
         return {'query': row['query'], 'response': response}
-    return dataset.rename_columns({'instruction': 'query', 'output': 'response'})\
+
+    return dataset.rename_columns({'instruction': 'query', 'output': 'response'}) \
         .remove_columns(['input', 'file']).map(map_row).filter(lambda row: row['response'] is not None)


@@ -897,36 +898,56 @@ def process_hh_rlhf_cn(dataset):

     def reorganize_row(row):
         history = []
-        if isinstance(row['context'], str):
-            row['context'] = ast.literal_eval(row['context'])
-        if isinstance(row['chosen'], str):
-            row['chosen'] = ast.literal_eval(row['chosen'])
-        if isinstance(row['rejected'], str):
-            row['rejected'] = ast.literal_eval(row['rejected'])
-        for idx, h in enumerate(row['context']):
-            if idx % 2 == 0 and h['role'] != 'human':
-                return {'query': None}
-            if idx % 2 != 0 and h['role'] != 'assistant':
-                return {'query': None}
-            if idx % 2 == 0:
-                history.append([h['text'], None])
-            else:
-                history[-1][-1] = h['text']
-        if history[-1][-1] is not None:
-            return {'query': None}
-        query = history[-1][0]
-        history = history[:-1]
-        response = row['chosen']['text']
-        rejected_response = row['rejected']['text']
+        try:
+            if isinstance(row['context'], str):
+                row['context'] = ast.literal_eval(row['context'])
+            if isinstance(row['chosen'], str):
+                row['chosen'] = ast.literal_eval(row['chosen'])
+            if isinstance(row['rejected'], str):
+                row['rejected'] = ast.literal_eval(row['rejected'])
+            for idx, h in enumerate(row['context']):
+                if idx % 2 == 0 and h['role'] != 'human':
+                    raise ValueError()
+                if idx % 2 != 0 and h['role'] != 'assistant':
+                    raise ValueError()
+                if idx % 2 == 0:
+                    history.append([h['text'], None])
+                else:
+                    history[-1][-1] = h['text']
+            if history[-1][-1] is not None:
+                raise ValueError()
+            query = history[-1][0]
+            history = history[:-1]
+            response = row['chosen']['text']
+            rejected_response = row['rejected']['text']
+        except:  # noqa
+            return {
+                'query': '',
+                'response': '',
+                'rejected_response': '',
+                'history': [],
+            }
         return {
             'query': query,
             'response': response,
             'rejected_response': rejected_response,
             'history': history,
         }

-    return dataset.map(reorganize_row).filter(
-        lambda row: row['query'] is not None)
+    def row_can_be_parsed(row):
+        try:
+            if isinstance(row['context'], str):
+                row['context'] = ast.literal_eval(row['context'])
+            if isinstance(row['chosen'], str):
+                row['chosen'] = ast.literal_eval(row['chosen'])
+            if isinstance(row['rejected'], str):
+                row['rejected'] = ast.literal_eval(row['rejected'])
+            return True
+        except:  # noqa
+            return False
+
+    return dataset.filter(row_can_be_parsed).map(reorganize_row).filter(
+        lambda row: row['query'])


 register_dataset(
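Two things change in `process_hh_rlhf_cn`: malformed rows now fail inside a `try` block and return empty-string sentinels (keeping the column schema uniform across rows), and a `row_can_be_parsed` pre-filter drops rows whose fields cannot be parsed at all. A minimal, self-contained sketch of the sentinel-then-filter pattern, with made-up rows:

```python
import ast
from datasets import Dataset

def parse_row(row):
    # Mirrors the pattern above: any failure yields an empty-string
    # sentinel so every row keeps the same schema; the trailing filter
    # on truthiness then drops the sentinel rows.
    try:
        ctx = ast.literal_eval(row['context'])
        return {'query': ctx[-1]['text']}
    except Exception:
        return {'query': ''}

ds = Dataset.from_dict({'context': [
    "[{'role': 'human', 'text': 'hi'}]",  # parseable row
    'not-a-python-literal',               # malformed row, dropped below
]})
clean = ds.map(parse_row).filter(lambda row: row['query'])
assert clean['query'] == ['hi']
```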
