Skip to content

Commit f9a925a

Browse files
authored
[model] support glm-4.5 agent_template (#5305)
1 parent 7cfb5c8 commit f9a925a

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

swift/llm/template/template/glm.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
6565
agent_template: str = 'glm4_0414'
6666

6767

68+
@dataclass
69+
class GLM4_5TemplateMeta(GLM4_0414TemplateMeta):
70+
agent_template: str = 'glm4_5'
71+
72+
6873
class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
6974
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])
7075

@@ -234,8 +239,18 @@ class GLM4_5Template(ThinkingTemplate):
234239
no_think_prefix = '<think></think>\n'
235240
history_think_prefix = '<think></think>\n'
236241

242+
def _swift_encode(self, inputs: StdTemplateInputs):
243+
res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
244+
# When it's a tool_call, avoid generating <|observation|><|user|>
245+
penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
246+
if isinstance(penultimate_content,
247+
str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
248+
res_context_list = res_context_list[:-1]
249+
answer_len -= 1
250+
return res_context_list, loss_scale_list, answer_len
251+
237252

238-
register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
253+
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
239254

240255
register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))
241256

swift/plugin/agent_template/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from .base import BaseAgentTemplate
33
from .extra import ReactGRPOAgentTemplate
4-
from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate
4+
from .glm4 import GLM4_5AgentTemplate, GLM4_0414AgentTemplate, GLM4AgentTemplate
55
from .hermes import HermesAgentTemplate, HunyuanHermesAgentTemplate
66
from .llama import Llama3AgentTemplate, Llama4AgentTemplate
77
from .mistral import MistralAgentTemplate
@@ -23,6 +23,7 @@
2323
'toolbench': ToolBenchAgentTemplate, # ref: https://modelscope.cn/datasets/swift/ToolBench
2424
'glm4': GLM4AgentTemplate,
2525
'glm4_0414': GLM4_0414AgentTemplate, # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
26+
'glm4_5': GLM4_5AgentTemplate,
2627
'llama3': Llama3AgentTemplate,
2728
'llama4': Llama4AgentTemplate,
2829
# extra

swift/plugin/agent_template/glm4.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def _find_function_call(single_content: str) -> Optional['Function']:
2222
matches = pattern.findall(single_content)
2323
if not matches:
2424
return
25-
2625
name, arguments = matches[0]
2726
return Function(name=name, arguments=arguments)
2827

@@ -77,3 +76,72 @@ def _format_tool_calls(self, tool_call_messages) -> str:
7776

7877
class GLM4_0414AgentTemplate(GLM4AgentTemplate):
7978
is_glm4_0414 = True
79+
80+
81+
class GLM4_5AgentTemplate(BaseAgentTemplate):
82+
83+
@staticmethod
84+
def _find_function_call(single_content: str) -> Optional['Function']:
85+
from swift.llm.infer import Function
86+
single_content = single_content.strip()
87+
func_name_match = re.match(r'^([^\n<]+)', single_content)
88+
if not func_name_match:
89+
return None
90+
func_name = func_name_match.group(1).strip()
91+
keys = re.findall(r'<arg_key>(.*?)</arg_key>', single_content, re.DOTALL)
92+
values = re.findall(r'<arg_value>(.*?)</arg_value>', single_content, re.DOTALL)
93+
if len(keys) != len(values):
94+
return None
95+
args = {k.strip(): v.strip() for k, v in zip(keys, values)}
96+
return Function(name=func_name, arguments=json.dumps(args, ensure_ascii=False))
97+
98+
def get_toolcall(self, response: str) -> List['Function']:
99+
toolcall_list = re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
100+
functions = []
101+
for toolcall in toolcall_list:
102+
function = self._find_function_call(toolcall)
103+
if function:
104+
functions.append(function)
105+
if len(functions) == 0:
106+
# compat react_en
107+
return super().get_toolcall(response)
108+
return functions
109+
110+
def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
111+
tool_descs = [
112+
'# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>'
113+
]
114+
for tool in tools:
115+
tool_descs.append(f'{json.dumps(tool, ensure_ascii=False)}')
116+
tool_descs.append(
117+
'</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}\n<arg_key>{arg-key-1}</arg_key>\n<arg_value>{arg-value-1}</arg_value>\n<arg_key>{arg-key-2}</arg_key>\n<arg_value>{arg-value-2}</arg_value>\n...\n</tool_call>'
118+
)
119+
tool_descs = '\n'.join(tool_descs)
120+
if system.strip():
121+
tool_descs += '<|system|>\n' + system.strip()
122+
return tool_descs
123+
124+
def _format_tool_responses(
125+
self,
126+
assistant_content: str,
127+
tool_messages,
128+
) -> Tuple[str, 'Prompt']:
129+
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
130+
if with_action:
131+
return super()._format_tool_responses(assistant_content, tool_messages)
132+
res = []
133+
for _, tool_message in enumerate(tool_messages):
134+
tool_content = tool_message['content']
135+
res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
136+
res.append('<|assistant|>\n')
137+
return assistant_content, res
138+
139+
def _format_tool_calls(self, tool_call_messages) -> str:
140+
tool_calls = []
141+
for message in tool_call_messages:
142+
tool_call = self._parse_tool_call(message['content'])
143+
tool_calls.append(f"<tool_call>{tool_call['name']}")
144+
for arg_key, arg_value in tool_call['arguments'].items():
145+
tool_calls.append(f'<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>')
146+
tool_calls.append('</tool_call>')
147+
return '\n'.join(tool_calls) + '<|observation|>'

0 commit comments

Comments
 (0)