Skip to content

Commit 0937497

Browse files
Support qwen agent format (#2722)
1 parent da336a3 commit 0937497

File tree

10 files changed

+128
-28
lines changed

10 files changed

+128
-28
lines changed

swift/llm/infer/infer_engine/infer_engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,13 @@ def _gen_wrapper():
136136
pass
137137
return self._update_metrics(res, metrics)
138138

139-
def _get_toolcall(self, response: Union[str, List[Dict[str,
140-
Any]]]) -> Optional[List[ChatCompletionMessageToolCall]]:
139+
def _get_toolcall(self,
140+
response: Union[str, List[Dict[str, Any]]],
141+
tools_prompt='react_en') -> Optional[List[ChatCompletionMessageToolCall]]:
141142
if not isinstance(response, str):
142143
response = '\n'.join([resp['text'] for resp in response if resp['type'] == 'text'])
143144

144-
action, action_input = split_action_action_input(response)
145+
action, action_input = split_action_action_input(response, tools_prompt=tools_prompt)
145146
if action is None:
146147
return None
147148

swift/llm/infer/infer_engine/lmdeploy_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ async def _infer_stream_async(
211211
usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
212212
toolcall = None
213213
if is_finished:
214-
toolcall = self._get_toolcall(template.decode(output.token_ids))
214+
toolcall = self._get_toolcall(template.decode(output.token_ids), template.tools_prompt)
215215
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
216216
output.status.name == 'FINISH')
217217
choices = [
@@ -237,7 +237,7 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
237237
logprobs = self._get_logprobs(template.tokenizer, output.logprobs, output.token_ids, generation_config.logprobs)
238238

239239
usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
240-
toolcall = self._get_toolcall(response)
240+
toolcall = self._get_toolcall(response, template.tools_prompt)
241241
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
242242
output.status.name == 'FINISH')
243243
choices = [

swift/llm/infer/infer_engine/pt_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def _model_generate(**kwargs):
230230
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
231231
toolcall = None
232232
if is_finished[i]:
233-
toolcall = self._get_toolcall(template.decode(generate_ids))
233+
toolcall = self._get_toolcall(template.decode(generate_ids), template.tools_prompt)
234234
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens,
235235
is_finished[i])
236236

@@ -291,7 +291,7 @@ def _infer_full(self,
291291
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
292292
response = template.decode(generate_ids, template_inputs=template_inputs[i])
293293
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
294-
toolcall = self._get_toolcall(response)
294+
toolcall = self._get_toolcall(response, template.tools_prompt)
295295
choices = [
296296
ChatCompletionResponseChoice(
297297
index=0,

swift/llm/infer/infer_engine/vllm_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ async def _infer_stream_async(self, template: Template, inputs: Dict[str, Any],
300300
token_idxs[output.index] = len(output.token_ids)
301301
toolcall = None
302302
if output.is_finished:
303-
toolcall = self._get_toolcall(template.decode(output.token_ids))
303+
toolcall = self._get_toolcall(template.decode(output.token_ids), template.tools_prompt)
304304
choice = ChatCompletionResponseStreamChoice(
305305
index=output.index,
306306
delta=DeltaMessage(role='assistant', content=output.delta_text, tool_calls=toolcall),
@@ -328,7 +328,7 @@ async def _infer_full_async(self,
328328
response = template.decode(output.token_ids)
329329
logprobs = self._get_logprobs(template.tokenizer, output.logprobs, output.token_ids,
330330
generation_config.logprobs)
331-
toolcall = self._get_toolcall(response)
331+
toolcall = self._get_toolcall(response, template.tools_prompt)
332332
choice = ChatCompletionResponseChoice(
333333
index=output.index,
334334
message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),

swift/llm/template/template_inputs.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class StdTemplateInputs:
113113
videos: List[str] = field(default_factory=list)
114114
objects: List[Dict[str, Any]] = field(default_factory=list)
115115

116+
agent_keyword: Optional[Dict[str, str]] = None
117+
116118
def __post_init__(self):
117119
self.image_idx = 0
118120
self.audio_idx = 0
@@ -125,6 +127,8 @@ def __post_init__(self):
125127
self.videos = [self.videos]
126128
if self.audios and not isinstance(self.audios, (list, tuple)):
127129
self.audios = [self.audios]
130+
if self.agent_keyword is None:
131+
self.agent_keyword = {}
128132

129133
def to_history(self):
130134
if not self.messages:
@@ -137,7 +141,7 @@ def is_multimodal(self):
137141

138142
@classmethod
139143
def from_dict(cls, inputs: Dict[str, Any], *, tools_prompt: str = 'react_en') -> 'StdTemplateInputs':
140-
from swift.plugin import get_tools_prompt
144+
from swift.plugin import get_tools_prompt, get_tools_keyword
141145
inputs = deepcopy(inputs)
142146
kwargs = {}
143147
for key in ['rejected_response', 'label']:
@@ -153,12 +157,15 @@ def from_dict(cls, inputs: Dict[str, Any], *, tools_prompt: str = 'react_en') ->
153157
else:
154158
system = None
155159

160+
keyword = None
156161
if tools is not None:
157162
if system is not None:
158-
logger.warning_once('You have tools prompt but you also have a system field, which will be ignored')
163+
logger.warning_once(
164+
'You have tools prompt but you also have a system field, so the system field will be ignored')
159165
if isinstance(tools, str):
160166
tools = json.loads(tools)
161167
system = get_tools_prompt(tools, tools_prompt)
168+
keyword = get_tools_keyword(tools_prompt)
162169

163170
media_kwargs = StdTemplateInputs.remove_messages_media(messages)
164171
for k in list(media_kwargs.keys()):
@@ -173,8 +180,8 @@ def from_dict(cls, inputs: Dict[str, Any], *, tools_prompt: str = 'react_en') ->
173180
else:
174181
media_kwargs[k] = inputs_mm_data
175182

176-
StdTemplateInputs.messages_join_observation(messages)
177-
return cls(messages=messages, system=system, objects=objects, **kwargs, **media_kwargs)
183+
StdTemplateInputs.messages_join_observation(messages, tools_prompt)
184+
return cls(messages=messages, system=system, objects=objects, agent_keyword=keyword, **kwargs, **media_kwargs)
178185

179186
@staticmethod
180187
def remove_messages_media(messages: Messages) -> Dict[str, Any]:
@@ -204,7 +211,7 @@ def remove_messages_media(messages: Messages) -> Dict[str, Any]:
204211
return res
205212

206213
@staticmethod
207-
def messages_join_observation(messages: Messages) -> None:
214+
def messages_join_observation(messages: Messages, tools_prompt='react_en') -> None:
208215
"""
209216
Joins observations from 'tool' message into the 'assistant' response.
210217
@@ -228,12 +235,14 @@ def messages_join_observation(messages: Messages) -> None:
228235
if len(messages) < 2:
229236
return
230237
i = 1
238+
from swift.plugin import get_tools_keyword
239+
keyword = get_tools_keyword(tools_prompt)
231240
while i < len(messages):
232241
pre_message, message = messages[i - 1], messages[i]
233242
pre_role, pre_content = pre_message['role'], pre_message['content']
234243
role, content = message['role'], message['content']
235-
if pre_role == 'assistant' and role == 'tool' and isinstance(pre_content,
236-
str) and pre_content.endswith('Observation:'):
244+
if (pre_role == 'assistant' and role == 'tool' and isinstance(pre_content, str)
245+
and pre_content.endswith(keyword.get('observation'))):
237246
assert isinstance(pre_content, str)
238247
pre_message['content'] = pre_content + content # assistant
239248
messages.pop(i) # remove tool

swift/llm/template/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,24 @@ def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float
184184
text_list[i:i + 1] = segments
185185

186186

187-
def split_action_action_input(response: str) -> Tuple[Optional[str], Optional[str]]:
187+
def split_action_action_input(response: str, tools_prompt='react_en') -> Tuple[Optional[str], Optional[str]]:
188+
188189
agent_keyword = [
189190
'action:', 'Action:', 'ACTION:', 'action input:', 'Action Input:', 'Action input:', 'ACTION INPUT:', 'Thought:',
190191
'Final Answer:', 'Observation:'
191192
]
193+
from swift.plugin import get_tools_keyword
194+
keyword = get_tools_keyword(tools_prompt)
195+
for key in keyword.values():
196+
if key not in agent_keyword:
197+
agent_keyword.append(key)
192198
agent_parts = split_str_parts_by(response, agent_keyword)
193199
action = None
194200
action_input = None
195201
for c in agent_parts:
196-
if c['key'].lower() == 'action:':
202+
if c['key'].lower() == keyword['action'].lower():
197203
action = c['content']
198-
elif c['key'].lower() == 'action input:':
204+
elif c['key'].lower() == keyword['action_input'].lower():
199205
action_input = c['content']
200206
if action:
201207
action = action.strip().replace('\n', '')

swift/plugin/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .loss_scale import loss_scale_map
1111
from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric
1212
from .optimizer import optimizers_map
13-
from .tools import get_tools_prompt
13+
from .tools import get_tools_prompt, get_tools_keyword
1414
from .tuner import Tuner, extra_tuners
1515

1616
else:
@@ -22,7 +22,7 @@
2222
'loss_scale': ['loss_scale_map'],
2323
'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric'],
2424
'optimizer': ['optimizers_map'],
25-
'tools': ['get_tools_prompt'],
25+
'tools': ['get_tools_prompt', 'get_tools_keyword'],
2626
'tuner': ['Tuner', 'extra_tuners'],
2727
}
2828

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"✿FUNCTION✿:": [2.0, 2.0],
3+
"✿ARGS✿:": [2.0, 2.0],
4+
"✿RETURN✿:": [1.0, 1.0],
5+
"✿RESULT✿:": [2.0, 0.0]
6+
}

swift/plugin/loss_scale.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,20 @@ def get_loss_scale(self,
154154
return super().get_loss_scale(context, context_type, is_last_round)
155155

156156

157+
class QwenLossScale(LossScale):
158+
loss_scale_config = 'qwen_loss_scale_config.json'
159+
160+
def get_loss_scale(self,
161+
context: str,
162+
context_type: ContextType,
163+
is_last_round: bool,
164+
*,
165+
query: Optional[str] = None):
166+
if context_type == ContextType.RESPONSE:
167+
return calculate_loss_scale(query, context, self.loss_scale_map)
168+
return super().get_loss_scale(context, context_type, is_last_round)
169+
170+
157171
class AlphaUmiLossScale(REACTLossScale):
158172
loss_scale_config = 'alpha_umi_loss_scale_config.json'
159173

@@ -171,5 +185,6 @@ def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwarg
171185
'alpha_umi': AlphaUmiLossScale(),
172186
'default': LossScale(),
173187
'last_round': LastRoundLossScale(),
188+
'qwen': QwenLossScale(),
174189
'all': TrainAllLossScale(),
175190
}

swift/plugin/tools.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import datetime as dt
3+
from dataclasses import dataclass
24
from typing import Dict, List, Optional, Union
35

6+
import json
7+
8+
9+
@dataclass
10+
class AgentKeyword:
11+
action: str = 'Action:'
12+
action_input: str = 'Action Input:'
13+
observation: str = 'Observation:'
14+
415

516
def format_react_en(tool_names, tool_descs):
617
REACT_PROMPT = """Answer the following questions as best as you can. You have access to the following tools:
@@ -18,6 +29,7 @@ def format_react_en(tool_names, tool_descs):
1829
1930
Begin!
2031
"""
32+
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
2133
return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
2234

2335

@@ -37,6 +49,7 @@ def format_react_zh(tool_names, tool_descs):
3749
3850
开始!
3951
"""
52+
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
4053
return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
4154

4255

@@ -46,6 +59,7 @@ def format_glm4(tool_names, tool_descs):
4659
# 可用工具
4760
4861
{tool_list}"""
62+
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
4963
tool_list = ''
5064
for name, tool in zip(tool_names, tool_descs):
5165
tool_list += f'## {name}\n\n{tool}\n\n'
@@ -78,28 +92,72 @@ def format_toolbench(tool_names, tool_descs):
7892
use function Finish->give_up_and_restart.
7993
2.Do not use origin tool names, use only subfunctions' names.
8094
Specifically, you have access to the following APIs: {tool_list}"""
95+
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
8196
return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))
8297

8398

99+
def format_qwen(tool_names, tool_descs):
100+
PROMPT = '''You are a helpful assistant.
101+
102+
当前时间:{date}
103+
104+
# 工具
105+
106+
## 你拥有如下工具:
107+
108+
{tool_list}
109+
110+
## 你可以在回复中插入以下命令以调用这些工具:
111+
112+
{format_list}
113+
'''
114+
# 定义星期映射
115+
weekdays = {0: '星期一', 1: '星期二', 2: '星期三', 3: '星期四', 4: '星期五', 5: '星期六', 6: '星期日'}
116+
now = dt.datetime.now()
117+
year = now.year
118+
month = now.month
119+
day = now.day
120+
weekday = weekdays[now.weekday()]
121+
formatted_date = f'{year}{month:02d}{day:02d}日,{weekday}'
122+
PROMPT = PROMPT.replace('{date}', formatted_date)
123+
tool_list = ''
124+
for name, tool in zip(tool_names, tool_descs):
125+
tool_list += f'### {name} \n{name}: {tool["description"]} 输入参数: {json.dumps(tool["parameters"])}\n'
126+
127+
PROMPT = PROMPT.replace('{tool_list}', tool_list)
128+
129+
format_list = ''
130+
for i, _ in enumerate(tool_names):
131+
format_list += f'✿FUNCTION✿:工具{i+1}的名称\n✿ARGS✿:工具{i + 1}的输入\n✿RESULT✿:工具{i + 1}的结果\n'
132+
PROMPT = PROMPT.replace('{format_list}', format_list)
133+
return PROMPT
134+
135+
84136
def format_custom(tool_names, tool_descs):
85137
PROMPT = '''你是一个人工智能助手。你的任务是针对用户的问题和要求提供适当的答复和支持。
86138
87139
# 可用工具
88140
89141
{tool_list}'''
90142
tool_list = ''
143+
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
91144
for name, tool in zip(tool_names, tool_descs):
92145
tool_list += f'## {name}\n\n{tool}\n\n'
93146
return PROMPT.format(tool_list=tool_list)
94147

95148

96149
# Add your prompt here, use --tools_prompt to train
97150
tools_prompt = {
98-
'react_en': format_react_en,
99-
'react_zh': format_react_zh,
100-
'glm4': format_glm4,
101-
'toolbench': format_toolbench,
102-
'custom': format_custom,
151+
'react_en': (format_react_en, AgentKeyword().__dict__),
152+
'react_zh': (format_react_zh, AgentKeyword().__dict__),
153+
'glm4': (format_glm4, AgentKeyword().__dict__),
154+
'toolbench': (format_toolbench, AgentKeyword().__dict__),
155+
'qwen': (format_qwen, AgentKeyword(
156+
action='✿FUNCTION✿:',
157+
action_input='✿ARGS✿:',
158+
observation='✿RESULT✿:',
159+
).__dict__),
160+
'custom': (format_custom, AgentKeyword().__dict__),
103161
}
104162

105163

@@ -111,10 +169,15 @@ def get_tools_prompt(tools: List[Dict[str, Union[str, Dict]]], prompt_format: st
111169
if isinstance(info, dict) and 'function' in info:
112170
info = info['function']
113171
tool_names.append(info['name'])
114-
tool_descs.append(str(info)) # info: dict
172+
tool_descs.append(info) # info: dict
115173
except KeyError:
116174
print('invalid tools format, please check'
117175
'https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/Agent-deployment-best-practice.md')
118176
return None
119-
prompt_format = tools_prompt.get(prompt_format) or format_toolbench
177+
prompt_format = tools_prompt.get(prompt_format, (None, None))[0] or format_toolbench
120178
return prompt_format(tool_names, tool_descs)
179+
180+
181+
def get_tools_keyword(prompt_format: str = 'react_en') -> Dict[str, str]:
182+
keyword = tools_prompt.get(prompt_format, (None, None))[1] or AgentKeyword().__dict__
183+
return keyword

0 commit comments

Comments (0)