Skip to content

Commit 51b3718

Browse files
authored
support agent packing (#3853)
1 parent cc1ece3 commit 51b3718

File tree

6 files changed

+183
-20
lines changed

6 files changed

+183
-20
lines changed

examples/infer/demo_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ def infer_stream(engine: 'InferEngine', infer_request: 'InferRequest'):
2424
gen_list = engine.infer([infer_request], request_config)
2525
query = infer_request.messages[0]['content']
2626
response = ''
27-
tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'
27+
tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}'
2828
print(f'query: {query}')
2929
for resp in gen_list[0]:
3030
if resp is None:
3131
continue
3232
delta = resp.choices[0].delta.content
3333
response += delta
3434
print(delta, end='', flush=True)
35-
print(tool, end='')
35+
print(tool)
3636

3737
infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
3838
gen_list = engine.infer([infer_request], request_config)

examples/train/agent/infer.md

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
以下为如何使用训练后Agent模型的简易教程:
2+
3+
## 方案一:使用swift app
4+
5+
1. 输入以下shell,启动app-ui:
6+
7+
```shell
8+
CUDA_VISIBLE_DEVICES=0 \
9+
swift app \
10+
--adapters output/vx-xxx/checkpoint-xxx \
11+
--stream true \
12+
--max_new_tokens 2048 \
13+
--verbose true \
14+
--stop_words 'Observation:'
15+
```
16+
17+
2. 将以下内容输入system中,点击重置system并清空历史记录:
18+
```
19+
Answer the following questions as best you can. You have access to the following APIs:
20+
1. TouristGuide: Call this tool to interact with the TouristGuide API. What is the TouristGuide API useful for? 旅游指南API,根据用户指定的条件查询目的地的旅游信息. Parameters: [{"name": "destination", "description": "指定需要查询的目的地,例如巴黎、纽约等", "required": "True"}, {"name": "attraction", "description": "指定需要查询的景点,例如埃菲尔铁塔、自由女神像等", "required": "False"}, {"name": "food", "description": "指定需要查询的美食,例如法国香槟、美国汉堡等", "required": "False"}, {"name": "hotel", "description": "指定需要查询的酒店,例如五星级、四星级等", "required": "False"}]
21+
22+
2. newsfeed: Call this tool to interact with the newsfeed API. What is the newsfeed API useful for? 获取指定主题的新闻列表. Parameters: [{"name": "topic", "description": "需要查询的新闻主题", "required": "False"}]
23+
24+
3. poemgen: Call this tool to interact with the poemgen API. What is the poemgen API useful for? 生成优美的诗歌. Parameters: [{"name": "theme", "description": "诗歌主题(例如:爱情、自然、季节等)", "required": "False"}]
25+
26+
4. Converter: Call this tool to interact with the Converter API. What is the Converter API useful for? 通过Python解释器进行单位转换. Parameters: [{"name": "from_unit", "description": "原单位", "required": "True"}, {"name": "to_unit", "description": "目标单位", "required": "True"}, {"name": "value", "description": "需要转换的数值", "required": "True"}]
27+
28+
5. musicPlaylist: Call this tool to interact with the musicPlaylist API. What is the musicPlaylist API useful for? 音乐播放列表API,提供多种音乐类型的播放列表. Parameters: [{"name": "type", "description": "音乐类型,例如流行、摇滚、古典等", "required": "True"}, {"name": "mood", "description": "音乐风格,例如抒情、动感、欢快等", "required": "False"}, {"name": "artist", "description": "歌手名字", "required": "False"}]
29+
30+
Use the following format:
31+
32+
Thought: you should always think about what to do
33+
Action: the action to take, should be one of the above tools[TouristGuide, newsfeed, poemgen, Converter, musicPlaylist]
34+
Action Input: the input to the action
35+
Observation: the result of the action
36+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
37+
Thought: I now know the final answer
38+
Final Answer: the final answer to the original input question
39+
Begin!
40+
```
41+
42+
3. 输入用户请求:`将200英里转换为公里`,并发送信息。模型将思考调用哪一个工具完成这一工作以及输出调用工具时所需的参数。遇到`Observation:`时终止输出,等待工具返回调用内容。
43+
```
44+
Action: Converter
45+
Action Input: {'from_unit': '英里', 'to_unit': '公里', 'value': 200}
46+
Observation:
47+
```
48+
49+
4. 模拟工具的返回,输入`tool:{'function_result': {'km': 321.8688}}`,模型将继续输入,并得到最终结果。
50+
```
51+
Thought: I now know the final answer
52+
Final Answer: 200英里等于321.8688公里。
53+
```
54+
55+
## 方案二:使用swift infer
56+
57+
1. 输入以下shell,启动命令行交互式推理界面
58+
59+
```shell
60+
CUDA_VISIBLE_DEVICES=0 \
61+
swift infer \
62+
--adapters output/vx-xxx/checkpoint-xxx \
63+
--stream true \
64+
--max_new_tokens 2048 \
65+
--stop_words 'Observation:'
66+
```
67+
68+
2. 依次输入以下内容:
69+
```
70+
<<< multi-line
71+
[INFO:swift] End multi-line input with `#`.
72+
[INFO:swift] Input `single-line` to switch to single-line input mode.
73+
<<<[M] reset-system#
74+
<<<[MS] Answer the following questions as best you can. You have access to the following APIs:
75+
1. translate: Call this tool to interact with the translate API. What is the translate API useful for? 将一种语言翻译成另一种语言. Parameters: [{"name": "text", "description": "需要翻译的文本", "required": "False"}, {"name": "source_lang", "description": "源语言,可选参数,默认为自动检测", "required": "False"}, {"name": "target_lang", "description": "目标语言,必选参数", "required": "False"}]
76+
77+
Use the following format:
78+
79+
Thought: you should always think about what to do
80+
Action: the action to take, should be one of the above tools[translate]
81+
Action Input: the input to the action
82+
Observation: the result of the action
83+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
84+
Thought: I now know the final answer
85+
Final Answer: the final answer to the original input question
86+
Begin!#
87+
<<<[M] single-line#
88+
<<< 翻译成法语:你好,我叫小明
89+
Action: translate
90+
Action Input: {'text': '你好,我叫小明', 'source_lang': 'auto', 'target_lang': 'fr'}
91+
Observation:
92+
--------------------------------------------------
93+
<<< tool:{'translated_text': 'Bonjour, je m\\'appelle Xiao Ming'}
94+
Thought: I now know the final answer
95+
Final Answer: 根据您的要求,我已经将“你好,我叫小明”翻译成了法语。翻译结果为“Bonjour, je m'appelle Xiao Ming”。
96+
```
97+
98+
## 方案三:使用Python
99+
100+
```python
101+
import os
102+
103+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
104+
105+
if __name__ == '__main__':
106+
from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig
107+
model = 'Qwen/Qwen2.5-3B'
108+
adapters = ['output/vx-xxx/checkpoint-xxx']
109+
110+
engine = PtEngine(model, max_batch_size=64, adapters=adapters)
111+
system = """Answer the following questions as best you can. You have access to the following APIs:
112+
1. trailFinder: Call this tool to interact with the trailFinder API. What is the trailFinder API useful for? API for finding nearby hiking trails based on user input.. Parameters: [{"name": "location", "description": "User's current location.", "required": "True"}, {"name": "distance", "description": "Maximum distance from user's location.", "required": "False"}, {"name": "difficulty", "description": "Specify the difficulty level of the trail.", "required": "False"}]
113+
114+
2. Factorial calculator: Call this tool to interact with the Factorial calculator API. What is the Factorial calculator API useful for? 计算正整数的阶乘. Parameters: [{"name": "n", "description": "需要计算阶乘的正整数", "required": "False"}]
115+
116+
3. weather: Call this tool to interact with the weather API. What is the weather API useful for? 天气查询API,查询指定城市的实时天气情况. Parameters: [{"name": "city", "description": "指定查询的城市名称", "required": "False"}, {"name": "date", "description": "指定查询的日期", "required": "False"}]
117+
118+
4. English to Chinese Translator: Call this tool to interact with the English to Chinese Translator API. What is the English to Chinese Translator API useful for? 将英文翻译成中文. Parameters: [{"name": "english_text", "description": "需要翻译的英文文本", "required": "True"}, {"name": "target_language", "description": "目标语言(默认为中文)", "required": "False"}]
119+
120+
Use the following format:
121+
122+
Thought: you should always think about what to do
123+
Action: the action to take, should be one of the above tools[trailFinder, Factorial calculator, weather, English to Chinese Translator]
124+
Action Input: the input to the action
125+
Observation: the result of the action
126+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
127+
Thought: I now know the final answer
128+
Final Answer: the final answer to the original input question
129+
Begin!
130+
"""
131+
request_config = RequestConfig(max_tokens=512, temperature=0, stop=['Observation:'], stream=True)
132+
messages = [{'role': 'system', 'content': system}]
133+
query = '北京今天的天气怎么样?'
134+
messages += [{'role': 'user', 'content': query}]
135+
gen_list = engine.infer([InferRequest(messages=messages)], request_config)
136+
response = ''
137+
tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'
138+
print(f'query: {query}')
139+
for resp in gen_list[0]:
140+
if resp is None:
141+
continue
142+
delta = resp.choices[0].delta.content
143+
response += delta
144+
print(delta, end='', flush=True)
145+
tool = "{'temp': 25, 'description': 'Partly cloudy', 'status': 'success'}"
146+
print(tool)
147+
messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
148+
gen_list = engine.infer([InferRequest(messages=messages)], request_config)
149+
for resp in gen_list[0]:
150+
if resp is None:
151+
continue
152+
print(resp.choices[0].delta.content, end='', flush=True)
153+
print()
154+
"""
155+
query: 北京今天的天气怎么样?
156+
Action: weather
157+
Action Input: {'city': '北京', 'date': '今天'}
158+
Observation:{'temp': 25, 'description': 'Partly cloudy', 'status': 'success'}
159+
Thought: I now know the final answer
160+
Final Answer: 根据API调用结果,北京今天的天气是部分多云,温度为25摄氏度。
161+
"""
162+
```

examples/train/agent/train.sh

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
# 2*24GiB
2-
nproc_per_node=2
3-
4-
CUDA_VISIBLE_DEVICES=0,1 \
5-
MASTER_PORT=29501 \
6-
NPROC_PER_NODE=$nproc_per_node \
1+
# 24GB
2+
CUDA_VISIBLE_DEVICES=0 \
73
swift sft \
8-
--model Qwen/Qwen2.5-7B-Instruct \
4+
--model Qwen/Qwen2.5-3B \
5+
--template default \
96
--train_type lora \
10-
--dataset swift/ToolBench \
7+
--dataset iic/ms_agent \
118
--loss_scale react \
129
--tools_prompt react_en \
1310
--torch_dtype bfloat16 \
@@ -18,14 +15,16 @@ swift sft \
1815
--lora_rank 8 \
1916
--lora_alpha 32 \
2017
--target_modules all-linear \
21-
--gradient_accumulation_steps $(expr 32 / $nproc_per_node) \
22-
--eval_steps 500 \
23-
--save_steps 500 \
18+
--gradient_accumulation_steps 8 \
19+
--eval_steps 50 \
20+
--save_steps 50 \
2421
--save_total_limit 2 \
2522
--logging_steps 5 \
2623
--max_length 8192 \
24+
--packing true \
25+
--use_liger_kernel true \
2726
--output_dir output \
2827
--warmup_ratio 0.05 \
28+
--attn_impl flash_attn \
2929
--dataloader_num_workers 4 \
30-
--deepspeed zero2 \
3130
--dataset_num_proc 16

requirements/install_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ pip install auto_gptq optimum bitsandbytes -U
77
pip install git+https://github.com/modelscope/ms-swift.git#egg=ms-swift[all]
88
pip install timm -U
99
pip install deepspeed -U
10-
pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U
10+
pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile liger_kernel -U
1111
# flash-attn: https://github.com/Dao-AILab/flash-attention/releases

swift/llm/template/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
409409
labels[0] = -100
410410
labels_list.append(labels)
411411
packed[key] = sum(labels_list, start=[])
412-
elif key == 'input_ids':
412+
elif key in {'input_ids', 'loss_scale'}:
413413
packed[key] = sum((x[0][key] for x in row), start=[])
414414
if 'position_ids' not in packed:
415415
packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])
@@ -1234,8 +1234,10 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
12341234
res = {}
12351235
if packing_mode:
12361236
# only support llm
1237-
for k in ['input_ids', 'labels', 'position_ids']:
1238-
res[k] = [self.gather_list(batch, k)]
1237+
for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
1238+
v = self.gather_list(batch, k)
1239+
if v:
1240+
res[k] = [v]
12391241
else:
12401242
inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
12411243
input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]

swift/llm/template/template_inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def messages_join_observation(messages: Messages, tools_prompt='react_en') -> No
214214
messages = [
215215
{'role': 'user', 'content': "What's the weather today in Hangzhou?"},
216216
{'role': 'assistant', 'content': 'Action: get_weather\nAction Input:\
217-
[{"location": "Hangzhou"}]\nObservations: It is 26 degrees Celsius and sunny in Hangzhou today.'}
217+
[{"location": "Hangzhou"}]\nObservations: It is 26 degrees Celsius and sunny in Hangzhou today.\n'}
218218
]
219219
"""
220220
if len(messages) < 2:
@@ -229,7 +229,7 @@ def messages_join_observation(messages: Messages, tools_prompt='react_en') -> No
229229
if (pre_role == 'assistant' and role == 'tool' and isinstance(pre_content, str)
230230
and pre_content.endswith(keyword.get('observation'))):
231231
assert isinstance(pre_content, str)
232-
pre_message['content'] = pre_content + content # assistant
232+
pre_message['content'] = pre_content + content + '\n' # assistant
233233
messages.pop(i) # remove tool
234234
elif (pre_role == 'assistant' and role == 'assistant' and isinstance(pre_content, str)
235235
and isinstance(content, str)):

0 commit comments

Comments
 (0)