diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index ee2b5d140f..040dab8472 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1110,9 +1110,14 @@ def _swift_encode(self, inputs: StdTemplateInputs):
elif response is not None:
# It is the final round, and the response exists (during training).
context_list.append('{{RESPONSE}}')
+ # The GLM-4.5 assistant part (tool call) may end with <|observation|>,
+ # and here we avoid adding <|user|>.
+ endswith_stop_words = any(
+ response.endswith(stop_word) for stop_word in template_meta.stop_words
+ if isinstance(stop_word, str))
# self.is_training needed because we may want to continue generation from
# the current response
- if self.is_training and not sep_token or self.task_type == 'embedding':
+ if (self.is_training and not sep_token or self.task_type == 'embedding') and not endswith_stop_words:
extra_context_list = template_meta.suffix
extra_context_type = ContextType.SUFFIX
elif template_meta.response_prefix:
@@ -1201,6 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
return encoded
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+ inputs.messages = deepcopy(inputs.messages)
template_backend = self.template_backend
if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
and self.task_type != 'seq_cls'):
diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py
index 067d6f3cc9..37e050d151 100644
--- a/swift/llm/template/template/glm.py
+++ b/swift/llm/template/template/glm.py
@@ -239,15 +239,12 @@ class GLM4_5Template(ThinkingTemplate):
no_think_prefix = '\n'
history_think_prefix = '\n'
- def _swift_encode(self, inputs: StdTemplateInputs):
- res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
- # When it's a tool_call, avoid generating <|observation|><|user|>
- penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
- if isinstance(penultimate_content,
- str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
- res_context_list = res_context_list[:-1]
- answer_len -= 1
- return res_context_list, loss_scale_list, answer_len
+ def _jinja_encode(self, inputs: StdTemplateInputs):
+ for message in inputs.messages:
+ if message['role'] == 'assistant' and isinstance(message['content'],
+ str) and message['content'].endswith('<|observation|>'):
+ message['content'] = message['content'][:-len('<|observation|>')]
+ return super()._jinja_encode(inputs)
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py
index 186be5dff6..e18a70fb3e 100644
--- a/swift/plugin/agent_template/glm4.py
+++ b/swift/plugin/agent_template/glm4.py
@@ -132,7 +132,7 @@ def _format_tool_responses(
if with_action:
return super()._format_tool_responses(assistant_content, tool_messages)
res = []
- for _, tool_message in enumerate(tool_messages):
+ for tool_message in tool_messages:
tool_content = tool_message['content']
res.append(f'\n\n{tool_content}\n')
res.append('<|assistant|>\n')
diff --git a/tests/test_align/test_template/test_agent.py b/tests/test_align/test_template/test_agent.py
index 8c492b5a89..11c091f0f7 100644
--- a/tests/test_align/test_template/test_agent.py
+++ b/tests/test_align/test_template/test_agent.py
@@ -1,6 +1,7 @@
import os
os.environ['SWIFT_DEBUG'] = '1'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
system = 'You are a helpful assistant.'
@@ -327,7 +328,29 @@ def test_hunyuan():
encoded2 = template.encode(data)
print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
print(f'labels: {template.safe_decode(encoded2["labels"])}')
- assert encoded['input_ids'] == encoded2['input_ids']
+ assert encoded['input_ids'][:-1] == encoded2['input_ids']
+
+
+def test_glm4_5():
+ engine = PtEngine('ZhipuAI/GLM-4.5-Air')
+ template = engine.default_template
+ template.template_backend = 'jinja'
+ _infer(engine, num_tools=2)
+
+ dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+ data = dataset[6]
+ data['messages'].insert(1, data['messages'][1])
+ data['messages'].insert(3, data['messages'][3])
+ template.template_backend = 'swift'
+ template.set_mode('train')
+ encoded = template.encode(data)
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
+ template.template_backend = 'jinja'
+ encoded2 = template.encode(data)
+ print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+ print(f'labels: {template.safe_decode(encoded2["labels"])}')
+ assert encoded['input_ids'][:-1] == encoded2['input_ids']
if __name__ == '__main__':
@@ -345,4 +368,5 @@ def test_hunyuan():
# test_glm4_0414()
# test_llama3()
# test_llama4()
- test_hunyuan()
+ # test_hunyuan()
+ test_glm4_5()