diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 6771ed23f1..30af2e1232 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1135,7 +1135,8 @@ def _swift_encode(self, inputs: StdTemplateInputs):
                 context_list.append('{{RESPONSE}}')
                 # self.is_training needed because we may want to continue generation from
                 # the current response
-                if self.is_training and not sep_token or self.task_type == 'embedding':
+                string_stop_words = tuple(s for s in template_meta.stop_words if isinstance(s, str))
+                if self.is_training and not sep_token and not response.endswith(string_stop_words) or self.task_type == 'embedding':
                     extra_context_list = template_meta.suffix
                     extra_context_type = ContextType.SUFFIX
             elif template_meta.response_prefix:
diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py
index 7ead62d889..a9d2362fde 100644
--- a/swift/llm/template/template/glm.py
+++ b/swift/llm/template/template/glm.py
@@ -65,6 +65,17 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
     agent_template: str = 'glm4_0414'
 
 
+@dataclass
+class GLM4_5TemplateMeta(GLMTemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
+    prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
+    chat_sep: Optional[Prompt] = field(default_factory=list)
+    suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])
+
+    agent_template: str = 'glm4_5'
+    stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])
+
+
 class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
     system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])
 
@@ -237,7 +248,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
 
 register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
 
-register_template(GLM4TemplateMeta(LLMTemplateType.glm4_5, template_cls=ThinkingTemplate))
+register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=ThinkingTemplate))
 
 register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))
 
diff --git a/swift/plugin/agent_template/__init__.py b/swift/plugin/agent_template/__init__.py
index 4b93b01983..05a16c4d9b 100644
--- a/swift/plugin/agent_template/__init__.py
+++ b/swift/plugin/agent_template/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .base import BaseAgentTemplate
 from .extra import ReactGRPOAgentTemplate
-from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate
+from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate, GLM4_5AgentTemplate
 from .hermes import HermesAgentTemplate, HunyuanHermesAgentTemplate
 from .llama import Llama3AgentTemplate, Llama4AgentTemplate
 from .mistral import MistralAgentTemplate
@@ -23,6 +23,7 @@
     'toolbench': ToolBenchAgentTemplate,  # ref: https://modelscope.cn/datasets/swift/ToolBench
     'glm4': GLM4AgentTemplate,
     'glm4_0414': GLM4_0414AgentTemplate,  # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
+    'glm4_5': GLM4_5AgentTemplate,
     'llama3': Llama3AgentTemplate,
     'llama4': Llama4AgentTemplate,
     # extra
diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py
index 0dfea2ab65..08b17bca2e 100644
--- a/swift/plugin/agent_template/glm4.py
+++ b/swift/plugin/agent_template/glm4.py
@@ -22,7 +22,6 @@ def _find_function_call(single_content: str) -> Optional['Function']:
         matches = pattern.findall(single_content)
         if not matches:
             return
-
         name, arguments = matches[0]
         return Function(name=name, arguments=arguments)
 
@@ -77,3 +76,70 @@ def _format_tool_calls(self, tool_call_messages) -> str:
 
 class GLM4_0414AgentTemplate(GLM4AgentTemplate):
     is_glm4_0414 = True
+
+
+class GLM4_5AgentTemplate(BaseAgentTemplate):
+
+    @staticmethod
+    def _find_function_call(single_content: str) -> Optional['Function']:
+        from swift.llm.infer import Function
+        single_content = single_content.strip()
+        func_name_match = re.match(r'^([^\n<]+)', single_content)
+        if not func_name_match:
+            return None
+        func_name = func_name_match.group(1).strip()
+        keys = re.findall(r'<arg_key>(.*?)</arg_key>', single_content, re.DOTALL)
+        values = re.findall(r'<arg_value>(.*?)</arg_value>', single_content, re.DOTALL)
+        if len(keys) != len(values):
+            return None
+        args = {k.strip(): v.strip() for k, v in zip(keys, values)}
+        return Function(name=func_name, arguments=json.dumps(args, ensure_ascii=False))
+
+    def get_toolcall(self, response: str) -> List['Function']:
+        toolcall_list = re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
+        functions = []
+        for toolcall in toolcall_list:
+            function = self._find_function_call(toolcall)
+            if function:
+                functions.append(function)
+        if len(functions) == 0:
+            # compat react_en
+            return super().get_toolcall(response)
+        return functions
+
+    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+        tool_descs = [
+            '# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>'
+        ]
+        for tool in tools:
+            tool_descs.append(f'{json.dumps(tool, ensure_ascii=False)}')
+        tool_descs.append(
+            '</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}\n<arg_key>{arg-key-1}</arg_key>\n<arg_value>{arg-value-1}</arg_value>\n<arg_key>{arg-key-2}</arg_key>\n<arg_value>{arg-value-2}</arg_value>\n...\n</tool_call>')
+        tool_descs = '\n'.join(tool_descs)
+        if system.strip():
+            tool_descs += '<|system|>\n' + system.strip()
+        return tool_descs
+
+    def _format_tool_responses(
+        self,
+        assistant_content: str,
+        tool_messages,
+    ) -> Tuple[str, 'Prompt']:
+        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+        if with_action:
+            return super()._format_tool_responses(assistant_content, tool_messages)
+        res = []
+        for _, tool_message in enumerate(tool_messages):
+            tool_content = tool_message['content']
+            res.append(f"\n<tool_response>\n{tool_content}\n</tool_response>")
+        res.append('<|assistant|>\n')
+        return assistant_content, res
+
+    def _format_tool_calls(self, tool_call_messages) -> str:
+        tool_calls = []
+        for message in tool_call_messages:
+            tool_call = self._parse_tool_call(message['content'])
+            tool_calls.append(f"<tool_call>{tool_call['name']}")
+            for arg_key, arg_value in tool_call["arguments"].items():
+                tool_calls.append(f"<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>")
+            tool_calls.append("</tool_call>")
+        return '\n'.join(tool_calls) + '<|observation|>'
\ No newline at end of file
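
Note (not part of the patch): a minimal, standalone sketch of the GLM-4.5 tool-call text format that GLM4_5AgentTemplate emits and parses. It only mirrors the regex logic of get_toolcall/_find_function_call above; the sample_response string and parse_glm4_5_tool_calls helper are illustrative names, not ms-swift APIs.

import json
import re

# Hypothetical assistant output in the GLM-4.5 tool-call format produced by
# _format_tool_calls above (sample text for illustration only).
sample_response = (
    '<tool_call>get_weather\n'
    '<arg_key>city</arg_key>\n'
    '<arg_value>Hangzhou</arg_value>\n'
    '</tool_call><|observation|>'
)


def parse_glm4_5_tool_calls(response: str):
    """Sketch of the parsing done by GLM4_5AgentTemplate.get_toolcall."""
    functions = []
    for block in re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL):
        block = block.strip()
        # Function name is the text before the first newline or '<'.
        name_match = re.match(r'^([^\n<]+)', block)
        if not name_match:
            continue
        keys = re.findall(r'<arg_key>(.*?)</arg_key>', block, re.DOTALL)
        values = re.findall(r'<arg_value>(.*?)</arg_value>', block, re.DOTALL)
        if len(keys) != len(values):
            continue
        args = {k.strip(): v.strip() for k, v in zip(keys, values)}
        functions.append({'name': name_match.group(1).strip(),
                          'arguments': json.dumps(args, ensure_ascii=False)})
    return functions


if __name__ == '__main__':
    # Expected: [{'name': 'get_weather', 'arguments': '{"city": "Hangzhou"}'}]
    print(parse_glm4_5_tool_calls(sample_response))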