Skip to content

Commit 9828d8f

Browse files
committed
[template] add GLM4_5Template
1 parent 49c0c91 commit 9828d8f

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

swift/llm/template/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,8 +1135,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
11351135
context_list.append('{{RESPONSE}}')
11361136
# self.is_training needed because we may want to continue generation from
11371137
# the current response
1138-
string_stop_words = tuple(s for s in template_meta.stop_words if isinstance(s, str))
1139-
if self.is_training and not sep_token and not response.endswith(string_stop_words) or self.task_type == 'embedding':
1138+
if self.is_training and not sep_token or self.task_type == 'embedding':
11401139
extra_context_list = template_meta.suffix
11411140
extra_context_type = ContextType.SUFFIX
11421141
elif template_meta.response_prefix:

swift/llm/template/template/glm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ class GLM4_0414Template(ThinkingTemplate, GLM4Template):
3838
pass
3939

4040

41+
class GLM4_5Template(ThinkingTemplate):
    """ThinkingTemplate variant that drops a trailing ``<|user|>`` after ``<|observation|>``."""

    def _swift_encode(self, inputs: StdTemplateInputs):
        """Encode ``inputs`` with the parent template, then trim a dangling ``<|user|>``.

        The parent's result is post-processed: when the second-to-last context
        is a string ending in ``<|observation|>`` and the last context is
        exactly ``<|user|>``, the last context is removed so the encoded
        sequence does not end in ``<|observation|><|user|>``.

        Args:
            inputs: Forwarded unchanged to the parent ``_swift_encode``.

        Returns:
            The ``(res_context_list, loss_scale_list, answer_len)`` triple from
            the parent. When the trim applies, the last context is removed,
            ``answer_len`` is reduced by one, and ``loss_scale_list`` is
            shortened alongside the contexts if it had the same length.
        """
        res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)

        # Fewer than two contexts: there is no penultimate entry to inspect.
        if len(res_context_list) < 2:
            return res_context_list, loss_scale_list, answer_len

        penultimate_content = res_context_list[-2]
        last_content = res_context_list[-1]
        # When it's a tool_call, avoid generating <|observation|><|user|>
        ends_with_observation = (
            isinstance(penultimate_content, str)
            and penultimate_content.endswith('<|observation|>')
        )
        if ends_with_observation and last_content == '<|user|>':
            n_contexts = len(res_context_list)
            res_context_list = res_context_list[:-1]
            # NOTE(review): assumes loss_scale_list carries one entry per
            # context - confirm against the parent _swift_encode. It is only
            # trimmed when its length matched the context list, so the two
            # returned lists stay the same length; otherwise it is untouched.
            if isinstance(loss_scale_list, list) and len(loss_scale_list) == n_contexts:
                loss_scale_list = loss_scale_list[:-1]
            answer_len -= 1
        return res_context_list, loss_scale_list, answer_len
50+
51+
4152
register_template(
4253
GLMTemplateMeta(
4354
LLMTemplateType.chatglm2,
@@ -241,7 +252,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
241252

242253
register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
243254

244-
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=ThinkingTemplate))
255+
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
245256

246257
register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))
247258

0 commit comments

Comments
 (0)