Skip to content

[model] support glm-4.5 agent #5305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,8 @@ def _swift_encode(self, inputs: StdTemplateInputs):
context_list.append('{{RESPONSE}}')
# self.is_training needed because we may want to continue generation from
# the current response
if self.is_training and not sep_token or self.task_type == 'embedding':
string_stop_words = tuple(s for s in template_meta.stop_words if isinstance(s, str))
if self.is_training and not sep_token and not response.endswith(string_stop_words) or self.task_type == 'embedding':
extra_context_list = template_meta.suffix
extra_context_type = ContextType.SUFFIX
elif template_meta.response_prefix:
Expand Down
13 changes: 12 additions & 1 deletion swift/llm/template/template/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
agent_template: str = 'glm4_0414'


@dataclass
class GLM4_5TemplateMeta(GLMTemplateMeta):
    """Template meta for GLM-4.5 chat models.

    Encodes the GLM-4.5 prompt dialect: a ``[gMASK]<sop>`` prefix, role tags
    ``<|system|>`` / ``<|user|>`` / ``<|assistant|>``, and no separator between
    chat turns (``chat_sep`` defaults to an empty list).
    """
    # Prepended once at the start of every conversation.
    prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
    # Per-turn prompt; {{QUERY}} is replaced with the user message.
    prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
    # Empty separator: turns are delimited by the role tags themselves.
    chat_sep: Optional[Prompt] = field(default_factory=list)
    # Appended after the final response during training/encoding.
    suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
    # Used instead of `prefix` when a system message is present.
    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])

    # Key into the agent-template registry (swift/plugin/agent_template).
    agent_template: str = 'glm4_5'
    # Generation stops on end-of-text or when the model starts a new role turn.
    stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])

class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])

Expand Down Expand Up @@ -237,7 +248,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:

register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))

register_template(GLM4TemplateMeta(LLMTemplateType.glm4_5, template_cls=ThinkingTemplate))
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=ThinkingTemplate))

register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))

Expand Down
3 changes: 2 additions & 1 deletion swift/plugin/agent_template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import BaseAgentTemplate
from .extra import ReactGRPOAgentTemplate
from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate
from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate, GLM4_5AgentTemplate
from .hermes import HermesAgentTemplate, HunyuanHermesAgentTemplate
from .llama import Llama3AgentTemplate, Llama4AgentTemplate
from .mistral import MistralAgentTemplate
Expand All @@ -23,6 +23,7 @@
'toolbench': ToolBenchAgentTemplate, # ref: https://modelscope.cn/datasets/swift/ToolBench
'glm4': GLM4AgentTemplate,
'glm4_0414': GLM4_0414AgentTemplate, # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
'glm4_5': GLM4_5AgentTemplate,
'llama3': Llama3AgentTemplate,
'llama4': Llama4AgentTemplate,
# extra
Expand Down
68 changes: 67 additions & 1 deletion swift/plugin/agent_template/glm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _find_function_call(single_content: str) -> Optional['Function']:
matches = pattern.findall(single_content)
if not matches:
return

name, arguments = matches[0]
return Function(name=name, arguments=arguments)

Expand Down Expand Up @@ -77,3 +76,70 @@ def _format_tool_calls(self, tool_call_messages) -> str:

class GLM4_0414AgentTemplate(GLM4AgentTemplate):
    # GLM-4-0414 agent variant; ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
    # NOTE(review): presumably this flag switches 0414-specific formatting paths
    # in GLM4AgentTemplate — confirm against the base class implementation.
    is_glm4_0414 = True


class GLM4_5AgentTemplate(BaseAgentTemplate):
    """Agent template for GLM-4.5 models.

    GLM-4.5 expresses tool calls in an XML dialect::

        <tool_call>{function-name}
        <arg_key>{key}</arg_key>
        <arg_value>{value}</arg_value>
        ...
        </tool_call>

    This class parses that dialect out of model responses and renders tool
    definitions, tool calls and tool responses back into it.
    """

    @staticmethod
    def _find_function_call(single_content: str) -> Optional['Function']:
        """Parse the body of one <tool_call> block into a Function.

        Returns None when the block is malformed (no function name, or an
        unbalanced number of <arg_key>/<arg_value> tags).
        """
        from swift.llm.infer import Function
        single_content = single_content.strip()
        # The function name is the first line: everything before a newline or tag.
        func_name_match = re.match(r'^([^\n<]+)', single_content)
        if not func_name_match:
            return None
        func_name = func_name_match.group(1).strip()
        keys = re.findall(r'<arg_key>(.*?)</arg_key>', single_content, re.DOTALL)
        values = re.findall(r'<arg_value>(.*?)</arg_value>', single_content, re.DOTALL)
        if len(keys) != len(values):
            # Unbalanced key/value tags: treat the whole call as malformed.
            return None
        args = {k.strip(): v.strip() for k, v in zip(keys, values)}
        return Function(name=func_name, arguments=json.dumps(args, ensure_ascii=False))

    def get_toolcall(self, response: str) -> List['Function']:
        """Extract every parseable tool call from a model response.

        Falls back to the base-class (ReAct-style) parser when no <tool_call>
        block yields a valid function.
        """
        toolcall_list = re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
        functions = []
        for toolcall in toolcall_list:
            function = self._find_function_call(toolcall)
            if function:
                functions.append(function)
        if not functions:
            # compat react_en
            return super().get_toolcall(response)
        return functions

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Render the tool-definition block, then append the user's system prompt."""
        tool_descs = [
            '# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>'
        ]
        for tool in tools:
            # A str entry is assumed to already be a JSON signature; dumping it
            # again would double-encode it into a quoted string.
            tool_descs.append(tool if isinstance(tool, str) else json.dumps(tool, ensure_ascii=False))
        tool_descs.append('</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}\n<arg_key>{arg-key-1}</arg_key>\n<arg_value>{arg-value-1}</arg_value>\n<arg_key>{arg-key-2}</arg_key>\n<arg_value>{arg-value-2}</arg_value>\n...\n</tool_call>')
        tool_descs = '\n'.join(tool_descs)
        if system.strip():
            tool_descs += '<|system|>\n' + system.strip()
        return tool_descs

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool results as <tool_response> blocks followed by the assistant tag.

        ReAct-style assistant contents (containing Action/Action Input keywords)
        keep the base-class formatting for backward compatibility.
        """
        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
        if with_action:
            return super()._format_tool_responses(assistant_content, tool_messages)
        res = []
        for tool_message in tool_messages:
            tool_content = tool_message['content']
            res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
        res.append('<|assistant|>\n')
        return assistant_content, res

    def _format_tool_calls(self, tool_call_messages) -> str:
        """Serialize assistant tool-call messages into <tool_call> XML blocks,
        terminated with <|observation|> to hand the turn to the tool."""
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            tool_calls.append(f"<tool_call>{tool_call['name']}")
            for arg_key, arg_value in tool_call['arguments'].items():
                if not isinstance(arg_value, str):
                    # Nested dict/list values are serialized as JSON rather than
                    # Python repr so the rendered call round-trips as valid JSON.
                    arg_value = json.dumps(arg_value, ensure_ascii=False)
                tool_calls.append(f'<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>')
            tool_calls.append('</tool_call>')
        return '\n'.join(tool_calls) + '<|observation|>'