@@ -69,6 +69,10 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
     agent_template: str = 'glm4_0414'


+class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])
+
+
 class GLM4VTemplate(Template):

     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
@@ -106,12 +110,132 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         return res


+class GLM4_1VTemplate(Template):
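+    # Hard-coded special-token IDs; assumed to match the GLM-4.1V tokenizer's
+    # image/video boundary tokens rather than being looked up at runtime.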
+    begin_of_image_token = 151339
+    end_of_image_token = 151340
+    image_token = 151343
+    begin_of_video_token = 151341
+    end_of_video_token = 151342
+    video_token = 151344
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
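+        # Emit in-band placeholders (-100 for image, -200 for video);
+        # _encode() later expands them into the real vision token sequences.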
+        # TODO: model-side video inference bug; only images are supported for now
+        assert media_type in ['image']
+        if media_type == 'image':
+            return [[-100]]
+        elif media_type == 'video':
+            return [[-200]]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        image_idx_list = findall(input_ids, -100)
+        video_idx_list = findall(input_ids, -200)
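+        # Expand each image placeholder into <begin> + N patch tokens + <end>.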
+        if image_idx_list:
+            images = inputs.images
+            image_inputs = processor.image_processor(images=images, return_tensors='pt')
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            encoded['image_grid_thw'] = image_grid_thw = image_inputs['image_grid_thw']
+            merge_length = processor.image_processor.merge_size**2
+            added_tokens_len = 0
+            for i, idx in enumerate(image_idx_list):
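+                # grid_thw is the (t, h, w) patch grid; merge_size**2 patches collapse into one token.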
+                num_image_tokens = image_grid_thw[i].prod() // merge_length
+                image_tokens = ([self.begin_of_image_token] + [self.image_token] * num_image_tokens
+                                + [self.end_of_image_token])
+
+                input_ids = input_ids[:added_tokens_len + idx] + image_tokens + input_ids[added_tokens_len + idx + 1:]
+                if labels is not None:
+                    labels = (labels[:added_tokens_len + idx] + [-100] * len(image_tokens)
+                              + labels[added_tokens_len + idx + 1:])
+                added_tokens_len += len(image_tokens) - 1
+
+        if video_idx_list:
+            # TODO: model-side video inference bug
+            assert len(video_idx_list) <= 1, (
+                f'GLM4.1V only supports 1 video, but {len(video_idx_list)} <video> tags were detected')
+            assert not image_idx_list, "GLM4.1V doesn't support inputs containing both video and images"
+
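+            # Load each video as a 4D array (frames, H, W, C); a list of image
+            # paths is treated as pre-extracted frames.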
+            video_fnames = inputs.videos
+            from transformers.video_utils import load_video
+            from transformers.image_utils import load_image
+            import numpy as np
+            video_metadata = []
+            videos = []
+            for fname in video_fnames:
+                if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                    video = [np.array(load_image(image_fname)) for image_fname in fname]
+                    # create a 4D video, because `load_video` always returns a 4D array
+                    video = np.stack(video)
+                    metadata = None
+                else:
+                    video, metadata = load_video(fname)
+                videos.append(video)
+                video_metadata.append(metadata)
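+            # Nest once more: the video processor expects a batch (list) of videos.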
+            videos = [videos]
+            video_metadata = [video_metadata]
+
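+            # The video processor yields per-frame patch grids plus frame timestamps.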
+            videos_inputs = processor.video_processor(videos=videos, video_metadata=video_metadata, return_tensors='pt')
+            encoded['pixel_values_videos'] = videos_inputs['pixel_values_videos']
+            encoded['video_grid_thw'] = video_grid_thw = videos_inputs['video_grid_thw']
+            timestamps = videos_inputs.pop('timestamps')
+            num_frames = len(video_grid_thw)
+            video_structure = [self.begin_of_video_token]
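+            # Normalize timestamps to a flat per-frame list, then pad/truncate to num_frames.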
+            if hasattr(timestamps, 'tolist'):
+                timestamps_list = timestamps.tolist()[0]
+            else:
+                timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
+            unique_timestamps = list(timestamps_list)
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+            merge_length = processor.video_processor.merge_size**2
+            added_tokens_len = 0
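+            # Per frame: <begin_of_image> + patch tokens + <end_of_image> + tokenized timestamp,
+            # all wrapped between <begin_of_video> and <end_of_video>.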
+            for frame_idx in range(num_frames):
+                timestamp_sec = selected_timestamps[frame_idx]
+                num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
+                timestamp_sec_token = processor.tokenizer(str(timestamp_sec))['input_ids']
+                frame_structure = ([self.begin_of_image_token] + [self.image_token] * num_image_tokens
+                                   + [self.end_of_image_token] + timestamp_sec_token)
+                video_structure += frame_structure
+            video_structure += [self.end_of_video_token]
+
+            for i, idx in enumerate(video_idx_list):
+                # BUG in GLM4.1V?: all video placeholders are expanded into the same tokens
+                # https://github.com/huggingface/transformers/blob/v4.53.0/src/transformers/models/glm4v/processing_glm4v.py#L165-L194
+                input_ids = (input_ids[:added_tokens_len + idx] + video_structure
+                             + input_ids[added_tokens_len + idx + 1:])
+                if labels is not None:
+                    labels = (labels[:added_tokens_len + idx] + [-100] * len(video_structure)
+                              + labels[added_tokens_len + idx + 1:])
+                added_tokens_len += len(video_structure) - 1
+
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
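+        # Sequential position ids over the fully expanded sequence.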
+        encoded['position_ids'] = list(range(len(input_ids)))
+        return encoded
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
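+        # Concatenate per-sample image/video grid_thw tensors along dim 0 for the batch.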
+        res = super()._data_collator_mm_data(batch)
+        for media_type in ['image', 'video']:
+            grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
+            if grid_thw is not None:
+                res[f'{media_type}_grid_thw'] = grid_thw
+        return res
+
+
 register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))

 register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))

 register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
+register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))
+
 glm4z1rumination_system = (
     '你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
     '今年是 2025 年。\n\n'