diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index ee2b5d140f..040dab8472 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1110,9 +1110,14 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             elif response is not None:
                 # It is the final round, and the response exists (during training).
                 context_list.append('{{RESPONSE}}')
+                # The GLM-4.5 assistant part (tool call) may end with <|observation|>,
+                # and here we avoid adding <|user|>.
+                endswith_stop_words = any(
+                    response.endswith(stop_word) for stop_word in template_meta.stop_words
+                    if isinstance(stop_word, str))
                 # self.is_training needed because we may want to continue generation from
                 # the current response
-                if self.is_training and not sep_token or self.task_type == 'embedding':
+                if (self.is_training and not sep_token or self.task_type == 'embedding') and not endswith_stop_words:
                     extra_context_list = template_meta.suffix
                     extra_context_type = ContextType.SUFFIX
                 elif template_meta.response_prefix:
@@ -1201,6 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
         return encoded
 
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        inputs.messages = deepcopy(inputs.messages)
         template_backend = self.template_backend
         if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
                 and self.task_type != 'seq_cls'):
diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py
index 067d6f3cc9..37e050d151 100644
--- a/swift/llm/template/template/glm.py
+++ b/swift/llm/template/template/glm.py
@@ -239,15 +239,12 @@ class GLM4_5Template(ThinkingTemplate):
     no_think_prefix = '\n'
     history_think_prefix = '\n'
 
-    def _swift_encode(self, inputs: StdTemplateInputs):
-        res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
-        # When it's a tool_call, avoid generating <|observation|><|user|>
-        penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
-        if isinstance(penultimate_content,
-                      str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
-            res_context_list = res_context_list[:-1]
-            answer_len -= 1
-        return res_context_list, loss_scale_list, answer_len
+    def _jinja_encode(self, inputs: StdTemplateInputs):
+        for message in inputs.messages:
+            if message['role'] == 'assistant' and isinstance(message['content'],
+                                                             str) and message['content'].endswith('<|observation|>'):
+                message['content'] = message['content'][:-len('<|observation|>')]
+        return super()._jinja_encode(inputs)
 
 
 register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py
index 186be5dff6..e18a70fb3e 100644
--- a/swift/plugin/agent_template/glm4.py
+++ b/swift/plugin/agent_template/glm4.py
@@ -132,7 +132,7 @@ def _format_tool_responses(
         if with_action:
             return super()._format_tool_responses(assistant_content, tool_messages)
         res = []
-        for _, tool_message in enumerate(tool_messages):
+        for tool_message in tool_messages:
             tool_content = tool_message['content']
             res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
         res.append('<|assistant|>\n')
diff --git a/tests/test_align/test_template/test_agent.py b/tests/test_align/test_template/test_agent.py
index 8c492b5a89..11c091f0f7 100644
--- a/tests/test_align/test_template/test_agent.py
+++ b/tests/test_align/test_template/test_agent.py
@@ -1,6 +1,7 @@
 import os
 
 os.environ['SWIFT_DEBUG'] = '1'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
 
 system = 'You are a helpful assistant.'
 
@@ -327,7 +328,29 @@ def test_hunyuan():
     encoded2 = template.encode(data)
     print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
     print(f'labels: {template.safe_decode(encoded2["labels"])}')
-    assert encoded['input_ids'] == encoded2['input_ids']
+    assert encoded['input_ids'][:-1] == encoded2['input_ids']
+
+
+def test_glm4_5():
+    engine = PtEngine('ZhipuAI/GLM-4.5-Air')
+    template = engine.default_template
+    template.template_backend = 'jinja'
+    _infer(engine, num_tools=2)
+
+    dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+    data = dataset[6]
+    data['messages'].insert(1, data['messages'][1])
+    data['messages'].insert(3, data['messages'][3])
+    template.template_backend = 'swift'
+    template.set_mode('train')
+    encoded = template.encode(data)
+    print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+    print(f'labels: {template.safe_decode(encoded["labels"])}')
+    template.template_backend = 'jinja'
+    encoded2 = template.encode(data)
+    print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+    print(f'labels: {template.safe_decode(encoded2["labels"])}')
+    assert encoded['input_ids'][:-1] == encoded2['input_ids']
 
 
 if __name__ == '__main__':
@@ -345,4 +368,5 @@ def test_hunyuan():
     # test_glm4_0414()
     # test_llama3()
     # test_llama4()
-    test_hunyuan()
+    # test_hunyuan()
+    test_glm4_5()
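
A side note on the rewritten condition in base.py: Python's `and` binds tighter than `or`, so without the parentheses the new `not endswith_stop_words` guard would attach only to the embedding branch, and a training sample whose final assistant turn (a tool call) ends with `<|observation|>` would still get the `<|user|>` suffix. A minimal standalone sketch of the two parses (toy flag values, not the real code path):

    # Hypothetical flags mirroring a GLM-4.5 tool-call round during training.
    is_training, sep_token, is_embedding, endswith_stop_words = True, None, False, True

    # Unparenthesized: parsed as (is_training and not sep_token) or (is_embedding and not endswith_stop_words).
    print(is_training and not sep_token or is_embedding and not endswith_stop_words)    # True  -> suffix appended
    # Parenthesized as in the patch: the stop-word guard covers both branches.
    print((is_training and not sep_token or is_embedding) and not endswith_stop_words)  # False -> suffix skipped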
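The `deepcopy(inputs.messages)` added at the top of `_encode` pairs with the new `GLM4_5Template._jinja_encode`, which strips the trailing `<|observation|>` from assistant content in place; without a copy, that mutation would leak back into the caller's messages, so a second encode of the same sample would see already-truncated content. A rough sketch of the aliasing hazard (toy data, hypothetical helper name):

    from copy import deepcopy

    messages = [{'role': 'assistant', 'content': 'get_weather(...)<|observation|>'}]

    def jinja_encode(msgs):
        # Strips the trailing <|observation|> in place, like _jinja_encode does.
        for m in msgs:
            if m['role'] == 'assistant' and m['content'].endswith('<|observation|>'):
                m['content'] = m['content'][:-len('<|observation|>')]

    jinja_encode(deepcopy(messages))    # caller's list is untouched
    print(messages[0]['content'])       # 'get_weather(...)<|observation|>'
    jinja_encode(messages)              # without the copy, the mutation persists
    print(messages[0]['content'])       # 'get_weather(...)'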