Merged
8 changes: 7 additions & 1 deletion swift/llm/template/base.py

@@ -1110,9 +1110,14 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             elif response is not None:
                 # It is the final round, and the response exists (during training).
                 context_list.append('{{RESPONSE}}')
+                # The GLM-4.5 assistant part (tool call) may end with <|observation|>,
+                # and here we avoid adding <|user|>.
+                endswith_stop_words = any(
+                    response.endswith(stop_word) for stop_word in template_meta.stop_words
+                    if isinstance(stop_word, str))
                 # self.is_training needed because we may want to continue generation from
                 # the current response
-                if self.is_training and not sep_token or self.task_type == 'embedding':
+                if self.is_training and not sep_token or self.task_type == 'embedding' and not endswith_stop_words:
                     extra_context_list = template_meta.suffix
                     extra_context_type = ContextType.SUFFIX
                 elif template_meta.response_prefix:
@@ -1201,6 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
         return encoded

     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        inputs.messages = deepcopy(inputs.messages)
         template_backend = self.template_backend
         if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
                 and self.task_type != 'seq_cls'):
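Taken together, the two hunks do two things: `_swift_encode` now skips the template suffix when the response already ends with a stop word (the GLM-4.5 tool-call case, where the assistant turn ends in `<|observation|>` and appending `<|user|>` would be wrong), and `_encode` deep-copies `inputs.messages` so that in-place message edits (see the `_jinja_encode` change below) cannot leak back into the caller's data. A minimal standalone sketch of the guard — the helper name and sample stop words are illustrative, not part of the PR:

from typing import List


def should_append_suffix(response: str, stop_words: List[str]) -> bool:
    """Mirror of the guard above: skip the template suffix when the
    assistant response already ends with a string stop word."""
    endswith_stop_words = any(
        response.endswith(stop_word) for stop_word in stop_words
        if isinstance(stop_word, str))
    return not endswith_stop_words


# GLM-4.5-style stop words (illustrative values):
stop_words = ['<|user|>', '<|observation|>']
print(should_append_suffix('get_weather(city="Beijing")<|observation|>', stop_words))  # False
print(should_append_suffix('The weather is sunny.', stop_words))                       # True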
15 changes: 6 additions & 9 deletions swift/llm/template/template/glm.py

@@ -239,15 +239,12 @@ class GLM4_5Template(ThinkingTemplate):
     no_think_prefix = '<think></think>\n'
     history_think_prefix = '<think></think>\n'

-    def _swift_encode(self, inputs: StdTemplateInputs):
-        res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
-        # When it's a tool_call, avoid generating <|observation|><|user|>
-        penultimate_content = res_context_list[-2] if len(res_context_list) >= 2 else None
-        if isinstance(penultimate_content,
-                      str) and penultimate_content.endswith('<|observation|>') and res_context_list[-1] == '<|user|>':
-            res_context_list = res_context_list[:-1]
-            answer_len -= 1
-        return res_context_list, loss_scale_list, answer_len
+    def _jinja_encode(self, inputs: StdTemplateInputs):
+        for message in inputs.messages:
+            if message['role'] == 'assistant' and isinstance(message['content'],
+                                                             str) and message['content'].endswith('<|observation|>'):
+                message['content'] = message['content'][:-len('<|observation|>')]
+        return super()._jinja_encode(inputs)


 register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
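The replacement moves the fix from the swift backend (post-hoc trimming of the rendered context list) to a jinja-side preprocessing step: trailing `<|observation|>` markers are stripped from assistant messages before the chat template renders them. A self-contained sketch of that preprocessing, assuming OpenAI-style message dicts as in the diff (the function name is illustrative):

def strip_trailing_observation(messages):
    """Remove a trailing '<|observation|>' from assistant turns so the
    jinja chat template does not emit '<|observation|>' plus its own
    role token back to back."""
    for message in messages:
        content = message['content']
        if message['role'] == 'assistant' and isinstance(content, str) \
                and content.endswith('<|observation|>'):
            message['content'] = content[:-len('<|observation|>')]
    return messages


msgs = [{'role': 'assistant', 'content': 'get_weather(city="Beijing")<|observation|>'}]
print(strip_trailing_observation(msgs))
# [{'role': 'assistant', 'content': 'get_weather(city="Beijing")'}]

Note that the mutation is in place, which is exactly why `_encode` in base.py now deep-copies `inputs.messages` first.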
2 changes: 1 addition & 1 deletion swift/plugin/agent_template/glm4.py

@@ -132,7 +132,7 @@ def _format_tool_responses(
         if with_action:
             return super()._format_tool_responses(assistant_content, tool_messages)
         res = []
-        for _, tool_message in enumerate(tool_messages):
+        for tool_message in tool_messages:
             tool_content = tool_message['content']
             res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
         res.append('<|assistant|>\n')
28 changes: 26 additions & 2 deletions tests/test_align/test_template/test_agent.py

@@ -1,6 +1,7 @@
 import os

 os.environ['SWIFT_DEBUG'] = '1'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
[Review comment — Contributor, severity: high]
Hardcoding CUDA_VISIBLE_DEVICES in a test file is not recommended, as it can cause tests to fail on machines with different GPU configurations (e.g., fewer than 4 GPUs, or no GPUs) or for other developers. This kind of configuration should be managed by the environment where the tests are run (e.g., a CI script or a local shell setup) rather than in the code itself. Please consider removing this line to improve test portability.
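One portable alternative in the spirit of this comment (a sketch, not part of the PR): set a default only when the environment has not already chosen devices, so CI or a local shell can override it.

import os

# Respect an externally provided device list; fall back to a default
# only when CUDA_VISIBLE_DEVICES is unset.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3')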
 system = 'You are a helpful assistant.'

@@ -327,7 +328,29 @@ def test_hunyuan():
     encoded2 = template.encode(data)
     print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
     print(f'labels: {template.safe_decode(encoded2["labels"])}')
-    assert encoded['input_ids'] == encoded2['input_ids']
+    assert encoded['input_ids'][:-1] == encoded2['input_ids']
+
+
+def test_glm4_5():
+    engine = PtEngine('ZhipuAI/GLM-4.5-Air')
+    template = engine.default_template
+    template.template_backend = 'jinja'
+    _infer(engine, num_tools=2)
+
+    dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+    data = dataset[6]
+    data['messages'].insert(1, data['messages'][1])
+    data['messages'].insert(3, data['messages'][3])
+    template.template_backend = 'swift'
+    template.set_mode('train')
+    encoded = template.encode(data)
+    print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+    print(f'labels: {template.safe_decode(encoded["labels"])}')
+    template.template_backend = 'jinja'
+    encoded2 = template.encode(data)
+    print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+    print(f'labels: {template.safe_decode(encoded2["labels"])}')
+    assert encoded['input_ids'][:-1] == encoded2['input_ids']


 if __name__ == '__main__':
@@ -345,4 +368,5 @@ def test_hunyuan():
     # test_glm4_0414()
     # test_llama3()
     # test_llama4()
-    test_hunyuan()
+    # test_hunyuan()
+    test_glm4_5()
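The `[:-1]` in both assertions presumably accounts for one trailing token: in train mode the swift backend appends the template suffix, while the jinja rendering of the same conversation stops at the final response. A toy illustration of the invariant the tests check (token values hypothetical):

swift_ids = [101, 7, 8, 9, 2]   # swift backend: ends with one extra suffix token
jinja_ids = [101, 7, 8, 9]      # jinja backend: no trailing suffix token
assert swift_ids[:-1] == jinja_ids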