9 changes: 8 additions & 1 deletion swift/llm/template/base.py
@@ -1041,6 +1041,7 @@ def _swift_prepare_inputs(self, inputs: StdTemplateInputs):
i += 1
pre_message['content'], tool_content = self.agent_template._format_tool_responses(
pre_content, messages[i_start:i + 1])
+ # Note: tool_content is a list.
messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
i = i_start + 1
elif pre_role == 'assistant' and role == 'assistant' or pre_role == 'user' and role == 'user':
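For orientation, the loop above collapses consecutive tool turns into a single message whose content is the list returned by _format_tool_responses. A hypothetical before/after (the exact payload strings are an assumption, modeled on the GLM4 agent template further down in this diff):

# Before the merge: two adjacent tool turns
# [..., {'role': 'tool', 'content': 'r1'}, {'role': 'tool', 'content': 'r2'}]
# After the merge: one turn whose content is the list built by _format_tool_responses,
# for the GLM4 agent template something like:
# [..., {'role': 'tool', 'content': ['\n<tool_response>\nr1\n</tool_response>',
#                                    '\n<tool_response>\nr2\n</tool_response>', ...]}]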
@@ -1109,9 +1110,14 @@ def _swift_encode(self, inputs: StdTemplateInputs):
elif response is not None:
# It is the final round, and the response exists (during training).
context_list.append('{{RESPONSE}}')
+ # The GLM-4.5 assistant part (tool call) may end with <|observation|>,
+ # and here we avoid adding <|user|>.
+ endswith_stop_words = any(
+     response.endswith(stop_word) for stop_word in template_meta.stop_words
+     if isinstance(stop_word, str))
# self.is_training needed because we may want to continue generation from
# the current response
- if self.is_training and not sep_token or self.task_type == 'embedding':
+ if self.is_training and not sep_token or self.task_type == 'embedding' and not endswith_stop_words:
extra_context_list = template_meta.suffix
extra_context_type = ContextType.SUFFIX
elif template_meta.response_prefix:
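An editorial note on the new condition: Python's and binds tighter than or, so the added not endswith_stop_words clause only guards the task_type == 'embedding' branch. If the intent (per the <|observation|> comment) is to also skip the suffix in the training branch, explicit parentheses would be needed. A standalone sketch of the grouping, using plain booleans rather than swift objects:

# A and B or C and D groups as (A and B) or (C and D)
def combined(is_training, sep_token, task_type, endswith_stop_words):
    return is_training and not sep_token or task_type == 'embedding' and not endswith_stop_words

def grouped(is_training, sep_token, task_type, endswith_stop_words):
    return (is_training and not sep_token) or (task_type == 'embedding' and not endswith_stop_words)

# The two agree everywhere; note the training branch ignores endswith_stop_words:
assert combined(True, '', 'causal_lm', True) is True
assert grouped(True, '', 'causal_lm', True) is True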
@@ -1200,6 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
return encoded

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+ inputs.messages = deepcopy(inputs.messages)
template_backend = self.template_backend
if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
and self.task_type != 'seq_cls'):
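A note on the deepcopy line just added: _swift_prepare_inputs rewrites the message list in place (see the slice assignment in the first hunk), so copying up front plausibly keeps the caller's request object intact. A minimal illustration of the aliasing hazard, independent of swift:

from copy import deepcopy

msgs = [{'role': 'tool', 'content': 'r1'}, {'role': 'tool', 'content': 'r2'}]
alias = msgs  # no copy: mutation is visible to the caller
alias[0:2] = [{'role': 'tool', 'content': ['r1', 'r2']}]
assert len(msgs) == 1  # the caller's list changed too

msgs = [{'role': 'tool', 'content': 'r1'}, {'role': 'tool', 'content': 'r2'}]
safe = deepcopy(msgs)  # the copy isolates the caller
safe[0:2] = [{'role': 'tool', 'content': ['r1', 'r2']}]
assert len(msgs) == 2  # the original is untouched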
15 changes: 6 additions & 9 deletions swift/llm/template/template/glm.py
@@ -239,15 +239,12 @@ class GLM4_5Template(ThinkingTemplate):
no_think_prefix = '<think></think>\n'
history_think_prefix = '<think></think>\n'

- def _swift_encode(self, inputs: StdTemplateInputs):
-     res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
-     # When it's a tool_call, avoid generating <|observation|><|user|>
-     penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
-     if isinstance(penultimate_content,
-                   str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
-         res_context_list = res_context_list[:-1]
-         answer_len -= 1
-     return res_context_list, loss_scale_list, answer_len
+ def _jinja_encode(self, inputs: StdTemplateInputs):
+     for message in inputs.messages:
+         if message['role'] == 'assistant' and isinstance(message['content'],
+                                                          str) and message['content'].endswith('<|observation|>'):
+             message['content'] = message['content'][:-len('<|observation|>')]
+     return super()._jinja_encode(inputs)


register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
5 changes: 3 additions & 2 deletions swift/llm/template/template/utils.py
@@ -57,8 +57,9 @@ def _swift_prepare_inputs(self, inputs):
# Delete the content before '</think>' in all assistant turns except the last round.
if message['role'] == 'assistant' and isinstance(message['content'], str) and i != len(messages) - 1:
if self.with_answer:
- message['content'] = message['content'].split('<answer>')[-1].rstrip().rstrip(
-     '</answer>').strip()
+ message['content'] = message['content'].split('<answer>')[-1].rstrip()
+ if message['content'].endswith('</answer>'):
+     message['content'] = message['content'][:-len('</answer>')].strip()
Contributor (review comment, severity: critical):

This is an excellent and critical fix. The previous implementation using rstrip('</answer>') was incorrect because str.rstrip(chars) treats its argument as a set of characters to strip from the end, not as a suffix string. For example, 'answer</answer>'.rstrip('</answer>') would incorrectly result in '' because every character of 'answer' is in the strip set. The new implementation using endswith and slicing is robust and removes only the suffix. Great catch!
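A minimal standalone demonstration of the difference (standard library only):

s = 'answer</answer>'
print(s.rstrip('</answer>'))        # '' -- the argument is a character set, not a suffix
if s.endswith('</answer>'):         # the approach taken by this fix
    print(s[:-len('</answer>')])    # 'answer'
print(s.removesuffix('</answer>'))  # 'answer' -- equivalent on Python 3.9+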

else:
message['content'] = self.history_think_prefix + message['content'].split(
'</think>')[-1].strip()
2 changes: 1 addition & 1 deletion swift/plugin/agent_template/glm4.py
@@ -132,7 +132,7 @@ def _format_tool_responses(
if with_action:
return super()._format_tool_responses(assistant_content, tool_messages)
res = []
- for _, tool_message in enumerate(tool_messages):
+ for tool_message in tool_messages:
tool_content = tool_message['content']
res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
res.append('<|assistant|>\n')
26 changes: 25 additions & 1 deletion tests/test_align/test_template/test_agent.py
@@ -1,6 +1,7 @@
import os

os.environ['SWIFT_DEBUG'] = '1'
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
Contributor (review comment, severity: high):
Hardcoding CUDA_VISIBLE_DEVICES in a test file is not recommended as it can cause tests to fail on machines with different GPU configurations (e.g., fewer than 4 GPUs, or no GPUs) or for other developers. This kind of configuration should be managed by the environment where the tests are run (e.g., a CI script or a local shell setup) rather than in the code itself. Please consider removing this line to improve test portability.
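If a default is still desired without clobbering an externally configured machine, one alternative sketch (not part of this PR) is os.environ.setdefault:

import os
# Respect a value exported by the CI script or shell; fall back only if unset.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')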


system = 'You are a helpful assistant.'

@@ -330,6 +331,28 @@ def test_hunyuan():
assert encoded['input_ids'] == encoded2['input_ids']


+ def test_glm4_5():
+     engine = PtEngine('ZhipuAI/GLM-4.5-Air')
+     template = engine.default_template
+     template.template_backend = 'jinja'
+     _infer(engine, num_tools=2)
+
+     dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+     data = dataset[6]
+     data['messages'].insert(1, data['messages'][1])
+     data['messages'].insert(3, data['messages'][3])
+     template.template_backend = 'swift'
+     template.set_mode('train')
+     encoded = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded["labels"])}')
+     template.template_backend = 'jinja'
+     encoded2 = template.encode(data)
+     print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+     print(f'labels: {template.safe_decode(encoded2["labels"])}')
+     assert encoded['input_ids'] == encoded2['input_ids']


if __name__ == '__main__':
from swift.plugin import agent_templates
from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
@@ -345,4 +368,5 @@ def test_hunyuan():
# test_glm4_0414()
# test_llama3()
# test_llama4()
-     test_hunyuan()
+     # test_hunyuan()
+     test_glm4_5()