Skip to content

Commit 85f519f

Browse files
authored
[template] update glm4.5 agent template (#5518)
1 parent 7f35502 commit 85f519f

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

swift/llm/template/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,9 +1110,14 @@ def _swift_encode(self, inputs: StdTemplateInputs):
11101110
elif response is not None:
11111111
# It is the final round, and the response exists (during training).
11121112
context_list.append('{{RESPONSE}}')
1113+
# The GLM-4.5 assistant part (tool call) may end with <|observation|>,
1114+
# and here we avoid adding <|user|>.
1115+
endswith_stop_words = any(
1116+
response.endswith(stop_word) for stop_word in template_meta.stop_words
1117+
if isinstance(stop_word, str))
11131118
# self.is_training needed because we may want to continue generation from
11141119
# the current response
1115-
if self.is_training and not sep_token or self.task_type == 'embedding':
1120+
if self.is_training and not sep_token or self.task_type == 'embedding' and not endswith_stop_words:
11161121
extra_context_list = template_meta.suffix
11171122
extra_context_type = ContextType.SUFFIX
11181123
elif template_meta.response_prefix:
@@ -1201,6 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
12011206
return encoded
12021207

12031208
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
1209+
inputs.messages = deepcopy(inputs.messages)
12041210
template_backend = self.template_backend
12051211
if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
12061212
and self.task_type != 'seq_cls'):

swift/llm/template/template/glm.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,15 +239,12 @@ class GLM4_5Template(ThinkingTemplate):
239239
no_think_prefix = '<think></think>\n'
240240
history_think_prefix = '<think></think>\n'
241241

242-
def _swift_encode(self, inputs: StdTemplateInputs):
243-
res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
244-
# When it's a tool_call, avoid generating <|observation|><|user|>
245-
penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
246-
if isinstance(penultimate_content,
247-
str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
248-
res_context_list = res_context_list[:-1]
249-
answer_len -= 1
250-
return res_context_list, loss_scale_list, answer_len
242+
def _jinja_encode(self, inputs: StdTemplateInputs):
    """Encode with the jinja backend after trimming a trailing observation marker.

    Every assistant message in ``inputs.messages`` whose content is a string ending
    with ``<|observation|>`` has that trailing marker removed. The message dicts are
    modified in place, and the inputs are then handed to the parent implementation.
    """
    marker = '<|observation|>'
    for msg in inputs.messages:
        if msg['role'] != 'assistant':
            continue
        text = msg['content']
        if not isinstance(text, str) or not text.endswith(marker):
            continue
        # NOTE(review): presumably the jinja chat template emits this marker itself - confirm.
        msg['content'] = text[:-len(marker)]
    return super()._jinja_encode(inputs)
251248

252249

253250
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))

swift/plugin/agent_template/glm4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _format_tool_responses(
132132
if with_action:
133133
return super()._format_tool_responses(assistant_content, tool_messages)
134134
res = []
135-
for _, tool_message in enumerate(tool_messages):
135+
for tool_message in tool_messages:
136136
tool_content = tool_message['content']
137137
res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
138138
res.append('<|assistant|>\n')

tests/test_align/test_template/test_agent.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
os.environ['SWIFT_DEBUG'] = '1'
4+
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
45

56
system = 'You are a helpful assistant.'
67

@@ -327,7 +328,29 @@ def test_hunyuan():
327328
encoded2 = template.encode(data)
328329
print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
329330
print(f'labels: {template.safe_decode(encoded2["labels"])}')
330-
assert encoded['input_ids'] == encoded2['input_ids']
331+
assert encoded['input_ids'][:-1] == encoded2['input_ids']
332+
333+
334+
def test_glm4_5():
    """Compare the swift and jinja template backends on a GLM-4.5 agent sample.

    Runs one inference pass with two tools under the jinja backend, then encodes a
    single function-calling sample in train mode with each backend. The swift
    ``input_ids`` without their final token must equal the jinja ``input_ids``.
    """

    def _show(result):
        # Print the decoded ids and labels, then hand the encoding back unchanged.
        print(f'input_ids: {template.safe_decode(result["input_ids"])}')
        print(f'labels: {template.safe_decode(result["labels"])}')
        return result

    engine = PtEngine('ZhipuAI/GLM-4.5-Air')
    template = engine.default_template
    template.template_backend = 'jinja'
    _infer(engine, num_tools=2)

    dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
    data = dataset[6]
    # Duplicate the message at index 1, then the message at index 3 of the updated list.
    data['messages'].insert(1, data['messages'][1])
    data['messages'].insert(3, data['messages'][3])

    template.template_backend = 'swift'
    template.set_mode('train')
    encoded = _show(template.encode(data))

    template.template_backend = 'jinja'
    encoded2 = _show(template.encode(data))

    assert encoded['input_ids'][:-1] == encoded2['input_ids']
331354

332355

333356
if __name__ == '__main__':
@@ -345,4 +368,5 @@ def test_hunyuan():
345368
# test_glm4_0414()
346369
# test_llama3()
347370
# test_llama4()
348-
test_hunyuan()
371+
# test_hunyuan()
372+
test_glm4_5()

0 commit comments

Comments
 (0)