Skip to content

Commit adb6f8f

Browse files
committed
Merge branch 'main' into release/3.1
2 parents c3e1da4 + 1701e17 commit adb6f8f

File tree

3 files changed

+75
-16
lines changed

3 files changed

+75
-16
lines changed

swift/llm/template/template_inputs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,10 @@ def messages_join_observation(messages: Messages, tools_prompt='react_en') -> No
232232
assert isinstance(pre_content, str)
233233
pre_message['content'] = pre_content + content # assistant
234234
messages.pop(i) # remove tool
235+
elif (pre_role == 'assistant' and role == 'assistant' and isinstance(pre_content, str)
236+
and isinstance(content, str)):
237+
# Consecutive messages from the assistant role need to be merged to prevent errors.
238+
pre_message['content'] = pre_content + content
239+
messages.pop(i)
235240
else:
236241
i += 1

swift/plugin/tools.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def format_react_en(tool_names, tool_descs):
2929
3030
Begin!
3131
"""
32-
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
32+
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
3333
return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
3434

3535

@@ -49,7 +49,7 @@ def format_react_zh(tool_names, tool_descs):
4949
5050
开始!
5151
"""
52-
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
52+
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
5353
return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
5454

5555

@@ -59,7 +59,7 @@ def format_glm4(tool_names, tool_descs):
5959
# 可用工具
6060
6161
{tool_list}"""
62-
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
62+
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
6363
tool_list = ''
6464
for name, tool in zip(tool_names, tool_descs):
6565
tool_list += f'## {name}\n\n{tool}\n\n'
@@ -92,7 +92,7 @@ def format_toolbench(tool_names, tool_descs):
9292
use function Finish->give_up_and_restart.
9393
2.Do not use origin tool names, use only subfunctions' names.
9494
Specifically, you have access to the following APIs: {tool_list}"""
95-
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
95+
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
9696
return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))
9797

9898

@@ -107,10 +107,20 @@ def format_qwen(tool_names, tool_descs):
107107
108108
{tool_list}
109109
110-
## 你可以在回复中插入以下命令以调用这些工具:
111-
112-
{format_list}
113-
'''
110+
## 你可以在回复中插入以下命令以并行调用N个工具:
111+
112+
✿FUNCTION✿: 工具1的名称,必须是[{tool_names}]之一
113+
✿ARGS✿: 工具1的输入
114+
✿FUNCTION✿: 工具2的名称
115+
✿ARGS✿: 工具2的输入
116+
...
117+
✿FUNCTION✿: 工具N的名称
118+
✿ARGS✿: 工具N的输入
119+
✿RESULT✿: 工具1的结果
120+
✿RESULT✿: 工具2的结果
121+
...
122+
✿RESULT✿: 工具N的结果
123+
✿RETURN✿: 根据工具结果进行回复'''
114124
# 定义星期映射
115125
weekdays = {0: '星期一', 1: '星期二', 2: '星期三', 3: '星期四', 4: '星期五', 5: '星期六', 6: '星期日'}
116126
now = dt.datetime.now()
@@ -122,15 +132,13 @@ def format_qwen(tool_names, tool_descs):
122132
PROMPT = PROMPT.replace('{date}', formatted_date)
123133
tool_list = ''
124134
for name, tool in zip(tool_names, tool_descs):
125-
tool_list += f'### {name} \n{name}: {tool["description"]} 输入参数: {json.dumps(tool["parameters"])}\n'
135+
desc = tool.get('description', '')
136+
parameters = json.dumps(params, ensure_ascii=False) if (params := tool.get('parameters')) else ''
137+
tool_list += f'### {name}\n\n{name}: {desc} 输入参数: {parameters} 此工具的输入应为JSON对象。'
126138

127139
PROMPT = PROMPT.replace('{tool_list}', tool_list)
128-
129-
format_list = ''
130-
for i, _ in enumerate(tool_names):
131-
format_list += f'✿FUNCTION✿:工具{i+1}的名称\n✿ARGS✿:工具{i + 1}的输入\n✿RESULT✿:工具{i + 1}的结果\n'
132-
PROMPT = PROMPT.replace('{format_list}', format_list)
133-
return PROMPT
140+
PROMPT = PROMPT.replace('{tool_names}', ','.join(tool_names))
141+
return PROMPT.rstrip()
134142

135143

136144
def format_custom(tool_names, tool_descs):
@@ -140,7 +148,7 @@ def format_custom(tool_names, tool_descs):
140148
141149
{tool_list}'''
142150
tool_list = ''
143-
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
151+
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
144152
for name, tool in zip(tool_names, tool_descs):
145153
tool_list += f'## {name}\n\n{tool}\n\n'
146154
return PROMPT.format(tool_list=tool_list)

tests/llm/test_template.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,52 @@ def test_template(self):
4343
response2 = _infer_model(pt_engine)
4444
assert response == response2
4545

46+
def test_tool_message_join(self):
    """Check that multi-round tool calls are joined into one assistant message.

    Builds a conversation of one user message followed by two
    assistant/tool rounds, where each assistant message consists of the
    observation keyword of the tool prompt under test. After
    ``StdTemplateInputs.messages_join_observation`` runs, only the user
    message and a single merged assistant message are expected to remain.
    """
    from copy import deepcopy

    from swift.llm.template.template_inputs import StdTemplateInputs
    from swift.plugin.tools import get_tools_keyword

    messages = [
        # first round
        {
            'role': 'user',
            'content': 'testing_user_message'
        },
        {
            'role': 'assistant',
            'content': ''
        },
        {
            'role': 'tool',
            'content': ''
        },
        # second round
        {
            'role': 'assistant',
            'content': ''
        },
        {
            'role': 'tool',
            'content': ''
        },
    ]

    # Test both tool prompt types. The loop variable must not be reassigned
    # inside the body, otherwise only the first type is ever exercised.
    for tool_prompt in ('react_en', 'qwen'):
        # Work on a copy: messages_join_observation mutates the list in place.
        test_messages = deepcopy(messages)
        obs_word = get_tools_keyword(tool_prompt).get('observation')
        test_messages[1]['content'] = f'{obs_word}'
        test_messages[2]['content'] = 'first_round_result\n'
        test_messages[3]['content'] = f'{obs_word}'
        test_messages[4]['content'] = 'second_round_result\n'
        StdTemplateInputs.messages_join_observation(test_messages, tools_prompt=tool_prompt)

        # Multi-round tool calling should be joined so that only one assistant
        # message is left; report the mutated list (not the template) on failure.
        assert len(test_messages) == 2, f'Tool prompt {tool_prompt} join failed, {test_messages}'
        assert test_messages[1]['content'] == f'{obs_word}first_round_result\n{obs_word}second_round_result\n'
91+
4692

4793
if __name__ == '__main__':
4894
unittest.main()

0 commit comments

Comments
 (0)