Commit 4dd1137

Support hermes loss_scale (#3963)
1 parent 0745667 commit 4dd1137

22 files changed

+223
-267
lines changed

README.md

Lines changed: 0 additions & 4 deletions
@@ -28,10 +28,6 @@
 <p align="center">
 <a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
 </p>
-<p align="center">
-<a href="https://swift2x-en.readthedocs.io/en/latest/">Swift2.x En Doc</a> &nbsp | &nbsp <a href="https://swift2x.readthedocs.io/zh-cn/latest/">Swift2.x中文文档</a> &nbsp
-</p>
-
 
 ## 📖 Table of Contents
 - [Groups](#-Groups)

README_CN.md

Lines changed: 0 additions & 4 deletions
@@ -29,10 +29,6 @@
 <p align="center">
 <a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
 </p>
-<p align="center">
-<a href="https://swift2x-en.readthedocs.io/en/latest/">Swift2.x En Doc</a> &nbsp | &nbsp <a href="https://swift2x.readthedocs.io/zh-cn/latest/">Swift2.x中文文档</a> &nbsp
-</p>
-
 
 ## 📖 Table of Contents
 - [User Groups](#-用户群)

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
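@@ -70,6 +70,7 @@
 - 🔥max_pixels: Maximum number of pixels (H\*W) for an input image of a multimodal model; images exceeding the limit are scaled down. Defaults to None, meaning no limit on the number of pixels.
 - 🔥agent_template: Agent template. It determines how the tool list is converted into the system prompt, how tool calls are extracted from the model response, and the template format for `{"role": "tool_call", "content": "xxx"}` and `{"role": "tool_response", "content": "xxx"}`. Options include "react_en", "hermes", "glm4", "qwen_en", "toolbench", and others; see [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/agent_template/__init__.py) for more. Defaults to None, meaning it is selected according to the model type.
 - response_prefix: Prefix string of the response; for example, QwQ-32B sets response_prefix to `'<think>\n'`. Defaults to None, meaning it is set automatically according to the model.
+  - Note: If you train a deepseek-r1/qwq model on a dataset that does not contain `<think>...</think>`, additionally pass `--response_prefix ''` when running inference on the trained model.
 - padding_side: The padding side when training with `batch_size>=2`. Options are 'left' and 'right'; defaults to 'right'. (For inference with batch_size>=2, only left padding is used.)
 - loss_scale: Loss weight setting for training tokens. Defaults to `'default'`, meaning all responses (including history) are included in the cross-entropy loss with weight 1. Options are 'default', 'last_round', 'all', and the loss scales needed for agents: 'react', 'agentflan', 'alpha_umi', and 'qwen'. 'last_round' computes the loss only on the last round's response, and 'all' computes the loss on all tokens. For the agent part, see [Pluginization](../Customization/插件化.md) and [Agent documentation](./Agent支持.md).
 - use_chat_template: Whether to use the chat template or the generation template. Defaults to `True`. `swift pt` automatically uses the generation template.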
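
Review note: the response_prefix entry separates two cases that are easy to confuse: leaving the option unset (None, so the model's default prefix applies) and passing an empty string (no prefix at all). The sketch below illustrates only that distinction. `DEFAULT_PREFIXES`, `resolve_response_prefix`, and `build_generation_prompt` are hypothetical names invented for this note, and the chat rendering is a placeholder, not ms-swift's implementation.

```python
from typing import Optional

# Hypothetical per-model defaults, for illustration only.
DEFAULT_PREFIXES = {"qwq": "<think>\n"}


def resolve_response_prefix(model_family: str, response_prefix: Optional[str] = None) -> str:
    """None means 'use the model default'; '' explicitly disables the prefix."""
    if response_prefix is None:
        return DEFAULT_PREFIXES.get(model_family, "")
    return response_prefix


def build_generation_prompt(rendered_chat: str, model_family: str,
                            response_prefix: Optional[str] = None) -> str:
    """Append the resolved prefix to an already rendered chat prompt."""
    return rendered_chat + resolve_response_prefix(model_family, response_prefix)


chat = "user: What is 2 + 2?\nassistant: "
with_default = build_generation_prompt(chat, "qwq")
without_prefix = build_generation_prompt(chat, "qwq", response_prefix="")
print(repr(with_default))
print(repr(without_prefix))
```

A model fine-tuned on data without `<think>...</think>` never saw the prefix during training, which is why the note asks for the empty string at inference time.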

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ Hints:
 - 🔥agent_template: Agent template, which determines how to convert the list of tools into a system, how to extract tool calls from the model's response, and specifies the template format for `{"role": "tool_call", "content": "xxx"}` and `{"role": "tool_response", "content": "xxx"}`. Optional values include "react_en", "hermes", "glm4", "qwen_en", "toolbench", etc. For more details, please check [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/agent_template/__init__.py). The default value is None, meaning it will be selected based on the model type.
 - norm_bbox: Controls how to scale bounding boxes (bbox). Options are 'norm1000' and 'none'. 'norm1000' represents scaling bbox coordinates to one-thousandths, and 'none' means no scaling. Default is None, automatically selected based on the model.
 - response_prefix: The prefix character for the response, for example, setting the response_prefix to `'<think>\n'` for QwQ-32B. The default is None, and it is automatically set according to the model.
+  - Note: If you are training the deepseek-r1/qwq model with a dataset that does not include `<think>...</think>`, please pass `--response_prefix ''` additionally when inferring after training.
 - padding_side: Padding side when `batch_size>=2` during training. Options are 'left' and 'right', with 'right' as the default. (For inference with batch_size>=2, only left padding is applied.)
 - loss_scale: Setting for the loss weight of training tokens. Default is `'default'`, meaning all responses (including history) are calculated with a cross-entropy loss of 1. Options are 'default', 'last_round', 'all', and agent-specific loss scales: 'react', 'agentflan', 'alpha_umi', and 'qwen'. 'last_round' means calculating only the loss of the last round's response, and 'all' calculates the loss for all tokens. For agent parts, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
 - use_chat_template: Use chat template or generation template, default is `True`. `swift pt` is automatically set to the generation template.
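
Review note: the three basic loss_scale settings described above ('default', 'last_round', 'all') can be read as per-token weights on the cross-entropy loss. The following is a minimal sketch under stated assumptions: a conversation is modelled as a list of hypothetical `(role, n_tokens)` segments, the function names are invented for this note, and the weighted loss is normalised by the sum of the weights, which is an assumption rather than something the documentation states. The agent-specific settings ('react', 'agentflan', 'alpha_umi', 'qwen') are not covered.

```python
from typing import List, Tuple

# Hypothetical representation: each segment is (role, number_of_tokens).
Segment = Tuple[str, int]


def token_loss_weights(segments: List[Segment], loss_scale: str = "default") -> List[float]:
    """Return one loss weight per token for 'default', 'last_round' and 'all'."""
    # Index of the final assistant segment, or -1 if there is none.
    last_assistant = max(
        (i for i, (role, _) in enumerate(segments) if role == "assistant"), default=-1)
    weights: List[float] = []
    for i, (role, n_tokens) in enumerate(segments):
        if loss_scale == "all":
            weight = 1.0
        elif loss_scale == "default":
            weight = 1.0 if role == "assistant" else 0.0
        elif loss_scale == "last_round":
            weight = 1.0 if i == last_assistant else 0.0
        else:
            raise ValueError(f"unsupported loss_scale in this sketch: {loss_scale}")
        weights.extend([weight] * n_tokens)
    return weights


def weighted_loss(per_token_losses: List[float], weights: List[float]) -> float:
    """Weighted mean of per-token losses (normalisation is an assumption)."""
    total = sum(weights)
    if total == 0:
        return 0.0
    return sum(l * w for l, w in zip(per_token_losses, weights)) / total


conversation = [("user", 3), ("assistant", 2), ("user", 2), ("assistant", 3)]
for mode in ("default", "last_round", "all"):
    print(mode, token_loss_weights(conversation, mode))
```

With two rounds, 'default' trains on both assistant replies, 'last_round' only on the final three tokens, and 'all' on every token including the user turns.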

examples/infer/demo_agent.py

Lines changed: 3 additions & 3 deletions
@@ -109,10 +109,10 @@ def infer_continue_generate(engine):
 from swift.llm import LmdeployEngine
 engine = LmdeployEngine(model)
 
-agent_template = agent_templates['hermes']() # react_en/qwen_en/qwen_en_parallel
-engine.default_template.agent_template = agent_template
+# agent_template = agent_templates['hermes']() # react_en/qwen_en/qwen_en_parallel
+# engine.default_template.agent_template = agent_template
 
 infer(engine, get_infer_request())
 infer_stream(engine, get_infer_request())
 
-infer_continue_generate(engine)
+# infer_continue_generate(engine)
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+CUDA_VISIBLE_DEVICES=0 \
+swift sft \
+    --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --train_type full \
+    --dataset AI-ModelScope/function-calling-chatml \
+    --agent_template react_en \
+    --loss_scale react \
+    --response_prefix '' \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 2 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-5 \
+    --gradient_accumulation_steps 8 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 8192 \
+    --save_only_model true \
+    --packing true \
+    --use_liger_kernel true \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --attn_impl flash_attn \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 16
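
Review note: the batch-related flags in this script interact, so a quick arithmetic check may help when adapting it. The sketch below assumes a single GPU (as `CUDA_VISIBLE_DEVICES=0` suggests) and that, with `--packing true`, each training sequence holds at most `max_length` tokens. The helper name `effective_batch_size` is hypothetical.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int) -> int:
    """Number of sequences that contribute to one optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


# Values taken from the script above; one visible GPU is an assumption.
eff = effective_batch_size(per_device_batch=1, grad_accum_steps=8, num_devices=1)
max_tokens_per_step = eff * 8192   # upper bound with --max_length 8192
sequences_per_log = 5 * eff        # --logging_steps 5
sequences_per_eval = 100 * eff     # --eval_steps 100 (same for --save_steps 100)

print(f"effective batch size: {eff}")
print(f"max tokens per optimizer step: {max_tokens_per_step}")
print(f"sequences between logs: {sequences_per_log}")
print(f"sequences between evals/saves: {sequences_per_eval}")
```

Raising the GPU count without lowering `--gradient_accumulation_steps` scales the effective batch size linearly, which usually calls for revisiting the learning rate.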

examples/train/agent/loss_scale/infer.md

Lines changed: 0 additions & 163 deletions
This file was deleted.
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# os.environ['SWIFT_DEBUG'] = '1'
+
+
+def infer(engine: 'InferEngine', infer_request: 'InferRequest'):
+    stop = [engine.default_template.agent_template.keyword.observation]  # compat react_en
+    request_config = RequestConfig(max_tokens=512, temperature=0, stop=stop)
+    resp_list = engine.infer([infer_request], request_config)
+    query = infer_request.messages[0]['content']
+    response = resp_list[0].choices[0].message.content
+    print(f'query: {query}')
+    print(f'response: {response}')
+    print(f'tool_calls: {resp_list[0].choices[0].message.tool_calls}')
+
+    tool = '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
+    print(f'tool_response: {tool}')
+    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
+    resp_list = engine.infer([infer_request], request_config)
+    response2 = resp_list[0].choices[0].message.content
+    print(f'response2: {response2}')
+
+
+def infer_stream(engine: 'InferEngine', infer_request: 'InferRequest'):
+    stop = [engine.default_template.agent_template.keyword.observation]
+    request_config = RequestConfig(max_tokens=512, temperature=0, stream=True, stop=stop)
+    gen_list = engine.infer([infer_request], request_config)
+    query = infer_request.messages[0]['content']
+    response = ''
+    print(f'query: {query}\nresponse: ', end='')
+    for resp in gen_list[0]:
+        if resp is None:
+            continue
+        delta = resp.choices[0].delta.content
+        response += delta
+        print(delta, end='', flush=True)
+    print()
+    print(f'tool_calls: {resp.choices[0].delta.tool_calls}')
+
+    tool = '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
+    print(f'tool_response: {tool}\nresponse2: ', end='')
+    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
+    gen_list = engine.infer([infer_request], request_config)
+    for resp in gen_list[0]:
+        if resp is None:
+            continue
+        print(resp.choices[0].delta.content, end='', flush=True)
+    print()
+
+
+def get_infer_request():
+    return InferRequest(
+        messages=[{
+            'role': 'user',
+            'content': "How's the weather in Beijing today?"
+        }],
+        tools=[{
+            'name': 'get_current_weather',
+            'description': 'Get the current weather in a given location',
+            'parameters': {
+                'type': 'object',
+                'properties': {
+                    'location': {
+                        'type': 'string',
+                        'description': 'The city and state, e.g. San Francisco, CA'
+                    },
+                    'unit': {
+                        'type': 'string',
+                        'enum': ['celsius', 'fahrenheit']
+                    }
+                },
+                'required': ['location']
+            }
+        }])
+
+
+if __name__ == '__main__':
+    from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig
+    from swift.plugin import agent_templates
+    model = 'Qwen/Qwen2.5-3B'
+    adapters = ['output/vx-xxx/checkpoint-xxx']
+    engine = PtEngine(model, adapters=adapters, max_batch_size=8)
+
+    # agent_template = agent_templates['hermes']()  # react_en/qwen_en/qwen_en_parallel
+    # engine.default_template.agent_template = agent_template
+
+    infer(engine, get_infer_request())
+    infer_stream(engine, get_infer_request())
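
Review note: the example prints `tool_calls`, which the agent template extracts from the raw model response. To make that step concrete, here is a standalone sketch that does not use ms-swift at all. It assumes a Hermes-style response in which each call is a JSON object wrapped in `<tool_call>` tags, and it checks the call against the `required` list of the tool schema used above. `extract_tool_calls` and `missing_required_arguments` are hypothetical helpers, not library functions.

```python
import json
import re
from typing import Any, Dict, List

# Assumed Hermes-style wrapper: <tool_call>{...json...}</tool_call>
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

TOOLS = [{
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    },
}]


def extract_tool_calls(response: str) -> List[Dict[str, Any]]:
    """Return every well-formed JSON tool call found in the response text."""
    calls = []
    for raw in TOOL_CALL_RE.findall(response):
        try:
            calls.append(json.loads(raw))
        except json.JSONDecodeError:
            continue  # skip malformed calls instead of failing the whole turn
    return calls


def missing_required_arguments(call: Dict[str, Any], tools: List[Dict[str, Any]]) -> List[str]:
    """List required parameters absent from the call; unknown tools raise KeyError."""
    schema = next((t for t in tools if t["name"] == call.get("name")), None)
    if schema is None:
        raise KeyError(f"unknown tool: {call.get('name')}")
    arguments = call.get("arguments", {})
    return [name for name in schema["parameters"].get("required", []) if name not in arguments]


response = ('Let me check.\n<tool_call>\n'
            '{"name": "get_current_weather", "arguments": {"location": "Beijing"}}\n'
            '</tool_call>')
calls = extract_tool_calls(response)
print(calls)
print(missing_required_arguments(calls[0], TOOLS))
```

Validating arguments before executing a tool keeps a malformed call from reaching the tool itself; the missing names could instead be fed back to the model as the tool response.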
