Commit 1477de3

[model] support GLM4.1V (#4804)
* init
* test glm41v
* fix image tokens
* fix grid
* test video
* fix video token
* rm model
* update doc
* disable video
1 parent: 9008153

File tree

9 files changed, +199 -3 lines changed


docs/source/Instruction/支持的模型和数据集.md

Lines changed: 2 additions & 0 deletions
@@ -645,6 +645,8 @@
 |[XiaomiMiMo/MiMo-VL-7B-RL](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-RL)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-RL](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-RL)|
 |[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b)|glm4v|glm4v|transformers>=4.42,<4.45|&#x2718;|-|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
 |[ZhipuAI/cogagent-9b-20241220](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)|glm4v|glm4v|transformers>=4.42|&#x2718;|-|[THUDM/cogagent-9b-20241220](https://huggingface.co/THUDM/cogagent-9b-20241220)|
+|[ZhipuAI/GLM-4.1V-9B-Base](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Base)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Base](https://huggingface.co/THUDM/GLM-4.1V-9B-Base)|
+|[ZhipuAI/GLM-4.1V-9B-Thinking](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Thinking)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking)|
 |[ZhipuAI/glm-edge-v-2b](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)|glm_edge_v|glm_edge_v|transformers>=4.46|&#x2718;|vision|[THUDM/glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b)|
 |[ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)|glm_edge_v|glm_edge_v|transformers>=4.46|&#x2718;|vision|[THUDM/glm-edge-4b-chat](https://huggingface.co/THUDM/glm-edge-4b-chat)|
 |[ZhipuAI/cogvlm-chat](https://modelscope.cn/models/ZhipuAI/cogvlm-chat)|cogvlm|cogvlm|transformers<4.42|&#x2718;|-|[THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 2 additions & 0 deletions
@@ -645,6 +645,8 @@ The table below introduces the models integrated with ms-swift:
 |[XiaomiMiMo/MiMo-VL-7B-RL](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-RL)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-RL](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-RL)|
 |[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b)|glm4v|glm4v|transformers>=4.42,<4.45|&#x2718;|-|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
 |[ZhipuAI/cogagent-9b-20241220](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)|glm4v|glm4v|transformers>=4.42|&#x2718;|-|[THUDM/cogagent-9b-20241220](https://huggingface.co/THUDM/cogagent-9b-20241220)|
+|[ZhipuAI/GLM-4.1V-9B-Base](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Base)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Base](https://huggingface.co/THUDM/GLM-4.1V-9B-Base)|
+|[ZhipuAI/GLM-4.1V-9B-Thinking](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Thinking)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking)|
 |[ZhipuAI/glm-edge-v-2b](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)|glm_edge_v|glm_edge_v|transformers>=4.46|&#x2718;|vision|[THUDM/glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b)|
 |[ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)|glm_edge_v|glm_edge_v|transformers>=4.46|&#x2718;|vision|[THUDM/glm-edge-4b-chat](https://huggingface.co/THUDM/glm-edge-4b-chat)|
 |[ZhipuAI/cogvlm-chat](https://modelscope.cn/models/ZhipuAI/cogvlm-chat)|cogvlm|cogvlm|transformers<4.42|&#x2718;|-|[THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf)|
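
Editor's note (not part of the diff): the two new rows bind the GLM-4.1V checkpoints to model_type `glm4_1v` and template `glm4_1v`, and require `transformers>=4.53`, which provides `Glm4vForConditionalGeneration`. A minimal sketch of loading the Thinking checkpoint with plain transformers, following the Hugging Face model card; the image URL and generation settings are illustrative only:

# Sketch only: shows what the `transformers>=4.53` requirement in the table buys you.
import torch
from transformers import AutoProcessor, Glm4vForConditionalGeneration

model_id = 'THUDM/GLM-4.1V-9B-Thinking'
processor = AutoProcessor.from_pretrained(model_id)
model = Glm4vForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto')

messages = [{
    'role': 'user',
    'content': [
        {'type': 'image', 'url': 'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'},
        {'type': 'text', 'text': 'Describe this image.'},
    ],
}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors='pt').to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True))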

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ class MLLMModelType:
     mimo_vl = 'mimo_vl'
 
     glm4v = 'glm4v'
+    glm4_1v = 'glm4_1v'
     glm_edge_v = 'glm_edge_v'
     cogvlm = 'cogvlm'
     cogagent_vqa = 'cogagent_vqa'

swift/llm/model/model/glm.py

Lines changed: 31 additions & 1 deletion
@@ -13,7 +13,8 @@
 from ..constant import LLMModelType, MLLMModelType
 from ..model_arch import ModelArch
 from ..patcher import patch_output_to_input_device
-from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
+from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
+                        get_model_tokenizer_with_flash_attn, register_model)
 from ..utils import AttnImpl, ModelInfo, safe_snapshot_download
 
 logger = get_logger()
@@ -231,6 +232,35 @@ def get_model_tokenizer_glm4v(model_dir: str,
     ))
 
 
+def get_model_tokenizer_glm4_1v(*args, **kwargs):
+    from transformers import Glm4vForConditionalGeneration
+    logger.info(
+        "If you encounter the error 'TypeError: group_images_by_shape() missing 1 required positional argument: "
+        "\"disable_grouping\"', please install the source version of the transformers library.")
+
+    kwargs['automodel_class'] = kwargs['automodel_class'] or Glm4vForConditionalGeneration
+    return get_model_tokenizer_multimodal(*args, **kwargs)
+
+
+register_model(
+    ModelMeta(
+        MLLMModelType.glm4_1v,
+        [
+            ModelGroup(
+                [
+                    Model('ZhipuAI/GLM-4.1V-9B-Base', 'THUDM/GLM-4.1V-9B-Base'),
+                    Model('ZhipuAI/GLM-4.1V-9B-Thinking', 'THUDM/GLM-4.1V-9B-Thinking'),
+                ],
+                requires=['transformers>=4.53'],
+            ),
+        ],
+        TemplateType.glm4_1v,
+        get_model_tokenizer_glm4_1v,
+        architectures=['Glm4vForConditionalGeneration'],
+        model_arch=ModelArch.glm4_1v,
+    ))
+
+
 def get_model_tokenizer_cogvlm(model_dir: str,
                                model_info: ModelInfo,
                                model_kwargs: Dict[str, Any],
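
Editor's note (illustrative sketch, not part of the commit): once registered, the new model type is reachable through ms-swift's generic loader. The registration above maps the ModelScope/Hugging Face ids to `get_model_tokenizer_glm4_1v`, which defaults the auto class to `Glm4vForConditionalGeneration`. Assuming the usual `get_model_tokenizer` entry point and its `load_model` flag (exact signature may differ by version):

# Sketch under assumptions: `get_model_tokenizer` with `load_model=False` returns
# only the processor/tokenizer; the model id resolves to model_type 'glm4_1v'.
from swift.llm import get_model_tokenizer

model, processor = get_model_tokenizer('ZhipuAI/GLM-4.1V-9B-Thinking', load_model=False)
print(processor.__class__.__name__)  # expected: the GLM-4.1V processor class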

swift/llm/model/model_arch.py

Lines changed: 9 additions & 0 deletions
@@ -34,6 +34,7 @@ class MLLMModelArch:
 
     cogvlm = 'cogvlm'
     glm4v = 'glm4v'
+    glm4_1v = 'glm4_1v'
     glm_edge_v = 'glm_edge_v'
 
     llama3_1_omni = 'llama3_1_omni'
@@ -511,6 +512,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         vision_tower='transformer.vision',
     ))
 
+register_model_arch(
+    MultiModelKeys(
+        MLLMModelArch.glm4_1v,
+        language_model='model.language_model',
+        aligner='model.visual.merger',
+        vision_tower='model.visual',
+    ))
+
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.idefics3,
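
Editor's note: the `MultiModelKeys` entry records module-name prefixes for the three sub-modules of `Glm4vForConditionalGeneration` (decoder, projector, ViT). ms-swift uses such prefixes to decide, for example, which parameters to freeze or which modules adapters should target. A rough sketch of the idea with a hypothetical helper (not ms-swift's actual implementation); the prefixes mirror the registration above:

# Hypothetical helper illustrating prefix-based parameter selection.
GLM4_1V_ARCH = {
    'language_model': 'model.language_model',
    'aligner': 'model.visual.merger',
    'vision_tower': 'model.visual',
}


def freeze_by_prefix(model, prefixes):
    """Set requires_grad=False on every parameter whose name starts with a prefix."""
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in prefixes):
            param.requires_grad = False


# e.g. train the language model and merger while keeping the ViT frozen:
# freeze_by_prefix(model, [GLM4_1V_ARCH['vision_tower']])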

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ class MLLMTemplateType:
     cogvlm2 = 'cogvlm2'
     cogvlm2_video = 'cogvlm2_video'
     glm4v = 'glm4v'
+    glm4_1v = 'glm4_1v'
     glm_edge_v = 'glm_edge_v'
 
     minicpmv = 'minicpmv'

swift/llm/template/template/glm.py

Lines changed: 124 additions & 0 deletions
@@ -69,6 +69,10 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
     agent_template: str = 'glm4_0414'
 
 
+class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])
+
+
 class GLM4VTemplate(Template):
 
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
@@ -106,12 +110,132 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         return res
 
 
+class GLM4_1VTemplate(Template):
+    begin_of_image_token = 151339
+    end_of_image_token = 151340
+    image_token = 151343
+    begin_of_video_token = 151341
+    end_of_video_token = 151342
+    video_token = 151344
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        # TODO: model video infer bug
+        assert media_type in ['image']
+        if media_type == 'image':
+            return [[-100]]
+        elif media_type == 'video':
+            return [[-200]]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        image_idx_list = findall(input_ids, -100)
+        video_idx_list = findall(input_ids, -200)
+        if image_idx_list:
+            images = inputs.images
+            image_inputs = processor.image_processor(images=images, return_tensors='pt')
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            encoded['image_grid_thw'] = image_grid_thw = image_inputs['image_grid_thw']
+            merge_length = processor.image_processor.merge_size**2
+            added_tokens_len = 0
+            for i, idx in enumerate(image_idx_list):
+                num_image_tokens = image_grid_thw[i].prod() // merge_length
+                image_tokens = [self.begin_of_image_token
+                                ] + [self.image_token] * num_image_tokens + [self.end_of_image_token]
+
+                input_ids = input_ids[:added_tokens_len + idx] + image_tokens + input_ids[added_tokens_len + idx + 1:]
+                if labels is not None:
+                    labels = labels[:added_tokens_len + idx] + [-100] * len(image_tokens) + labels[added_tokens_len
+                                                                                                   + idx + 1:]
+                added_tokens_len += len(image_tokens) - 1
+
+        if video_idx_list:
+            # TODO: model video infer bug
+            assert len(
+                video_idx_list) <= 1, f'GLM4.1V model only support 1 video, but detected {len(video_idx_list)} <video> '
+            assert not image_idx_list, "GLM4.1V model doesn't support inputs containing both video and images"
+
+            video_fnames = inputs.videos
+            from transformers.video_utils import load_video
+            from transformers.image_utils import load_image
+            import numpy as np
+            video_metadata = []
+            videos = []
+            for fname in video_fnames:
+                if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                    video = [np.array(load_image(image_fname)) for image_fname in fname]
+                    # create a 4D video because `load_video` always returns a 4D array
+                    video = np.stack(video)
+                    metadata = None
+                else:
+                    video, metadata = load_video(fname)
+                videos.append(video)
+                video_metadata.append(metadata)
+            videos = [videos]
+            video_metadata = [video_metadata]
+
+            videos_inputs = processor.video_processor(videos=videos, video_metadata=video_metadata, return_tensors='pt')
+            encoded['pixel_values_videos'] = videos_inputs['pixel_values_videos']
+            encoded['video_grid_thw'] = video_grid_thw = videos_inputs['video_grid_thw']
+            timestamps = videos_inputs.pop('timestamps')
+            num_frames = len(video_grid_thw)
+            video_structure = [self.begin_of_video_token]
+            if hasattr(timestamps, 'tolist'):
+                timestamps_list = timestamps.tolist()[0]
+            else:
+                timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
+            unique_timestamps = []
+            for idx in range(0, len(timestamps_list)):
+                unique_timestamps.append(timestamps_list[idx])
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+            merge_length = processor.video_processor.merge_size**2
+            added_tokens_len = 0
+            for frame_idx in range(num_frames):
+                timestamp_sec = selected_timestamps[frame_idx]
+                num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
+                timestamp_sec_token = processor.tokenizer(str(timestamp_sec))['input_ids']
+                frame_structure = [self.begin_of_image_token] + [self.image_token] * num_image_tokens + \
+                    [self.end_of_image_token] + timestamp_sec_token
+                video_structure += frame_structure
+            video_structure += [self.end_of_video_token]
+
+            for i, idx in enumerate(video_idx_list):
+                # BUG in GLM4.1V?: All video placeholder take same tokens
+                # https://github.com/huggingface/transformers/blob/v4.53.0/src/transformers/models/glm4v/processing_glm4v.py#L165-L194
+                input_ids = input_ids[:added_tokens_len + idx] + video_structure + \
+                    input_ids[added_tokens_len + idx + 1:]
+                if labels is not None:
+                    labels = labels[:added_tokens_len + idx] + [-100] * len(video_structure) + \
+                        labels[added_tokens_len + idx + 1:]
+                added_tokens_len += len(video_structure) - 1
+
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
+        encoded['position_ids'] = list(range(len(input_ids)))
+        return encoded
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+        res = super()._data_collator_mm_data(batch)
+        for media_type in ['image', 'video']:
+            grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
+            if grid_thw is not None:
+                res[f'{media_type}_grid_thw'] = grid_thw
+        return res
+
+
 register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))
 
 register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))
 
 register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
 
+register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))
+
 glm4z1rumination_system = (
     '你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
     '今年是 2025 年。\n\n'
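
Editor's note: the core of `_encode` is placeholder expansion. Each `<image>` marker (encoded as `-100`) is replaced by a begin-of-image token, N image tokens, and an end-of-image token, where N comes from the processor's `image_grid_thw` output divided by `merge_size**2`; the matching label span is filled with `-100` so the vision tokens are masked out of the loss, and each video frame additionally gets its timestamp tokenized and appended. A small worked example of the image arithmetic, with made-up grid values:

# Illustration only: hypothetical grid values showing how many placeholder tokens
# one image expands to in GLM4_1VTemplate._encode.
begin_of_image_token, image_token, end_of_image_token = 151339, 151343, 151340

image_grid_thw = (1, 22, 46)   # assumed (t, h, w) patch grid returned by the image processor
merge_size = 2                 # assumed spatial merge factor of the image processor

t, h, w = image_grid_thw
num_image_tokens = (t * h * w) // merge_size**2          # 1012 // 4 = 253
expanded = [begin_of_image_token] + [image_token] * num_image_tokens + [end_of_image_token]
print(len(expanded))  # 255 tokens replace the single -100 placeholder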

tests/test_align/test_template/test_video.py

Lines changed: 12 additions & 1 deletion
@@ -152,6 +152,16 @@ def test_qwen2_5_omni():
     assert response == response2 == ground_truth
 
 
+def test_glm4_1v():
+    messages = [{'role': 'user', 'content': '<video>What happened in the video?'}]
+    videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
+    pt_engine = PtEngine('ZhipuAI/GLM-4.1V-9B-Thinking')
+    response = _infer_model(pt_engine, messages=messages, videos=videos)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, videos=videos)
+    assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -165,4 +175,5 @@ def test_qwen2_5_omni():
     # test_minicpmo()
     # test_valley()
     # test_qwen2_5_vl()
-    test_qwen2_5_omni()
+    # test_qwen2_5_omni()
+    test_glm4_1v()  # bug now, wait model fix

tests/test_align/test_template/test_vision.py

Lines changed: 17 additions & 1 deletion
@@ -563,6 +563,21 @@ def test_kimi_vl_thinking():
         'The second image is an illustration of four sheep in a car')
 
 
+def test_glm4_1v():
+    models = ['ZhipuAI/GLM-4.1V-9B-Thinking']
+    messages = [{'role': 'user', 'content': '<image><image>What is the difference between the two images?'}]
+    images = [
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'
+    ]
+    for model in models:
+        pt_engine = PtEngine(model)
+        response = _infer_model(pt_engine, messages=messages, images=images)
+        pt_engine.default_template.template_backend = 'jinja'
+        response2 = _infer_model(pt_engine, messages=messages, images=images)
+        assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -616,4 +631,5 @@ def test_kimi_vl_thinking():
     # test_internvl3_8b()
     # test_internvl3_9b()
     # test_kimi_vl()
-    test_kimi_vl_thinking()
+    # test_kimi_vl_thinking()
+    test_glm4_1v()
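
Editor's note: the `_infer_model` helper used by these tests is not part of the diff. A standalone equivalent built on ms-swift's public engine API might look roughly like the following; API names follow the ms-swift 3.x docs, and the request settings are illustrative:

# Sketch under assumptions: PtEngine / InferRequest / RequestConfig as exposed by
# swift.llm; sampling settings are arbitrary examples.
from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('ZhipuAI/GLM-4.1V-9B-Thinking')
request = InferRequest(
    messages=[{'role': 'user', 'content': '<image><image>What is the difference between the two images?'}],
    images=[
        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png',
    ])
resp_list = engine.infer([request], RequestConfig(max_tokens=512, temperature=0))
print(resp_list[0].choices[0].message.content)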
