# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Any, Dict, List, Literal, Optional
@@ -183,49 +184,126 @@ class InternS1Template(Internvl2Template, ThinkingTemplate):
        'making your solution path and reasoning clear to others. '
        'Please put your thinking process within <think>...</think> tags.')

+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type in ['image', 'video']
+        if media_type == 'video':
+            if self.mode == 'vllm':
+                return ['<video>']
+            else:
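+                # -200 is this template's video placeholder token id, analogous to -100 for images (see _encode).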
+                return [[-200]]
+        return super().replace_tag(media_type, index, inputs)
+
    def _swift_encode(self, inputs: StdTemplateInputs):
        if inputs.system is None and self.template_meta.response_prefix == '<think>':
            inputs.system = self.InternS1DefaultThinkinngSystem

        return super()._swift_encode(inputs)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        from transformers.image_utils import make_flat_list_of_images
+        from transformers.image_utils import make_flat_list_of_images, concatenate_list
+        from transformers.video_utils import make_batched_videos
+        from swift.llm.template.vision_utils import load_video_hf
        import numpy as np
        encoded = super(InternvlTemplate, self)._encode(inputs)
        input_ids = encoded['input_ids']
-        idx_list = findall(input_ids, -100)
        labels = encoded['labels']
        loss_scale = encoded.get('loss_scale', None)
        images = inputs.images
-        if inputs.videos:
-            # TODO
-            raise NotImplementedError('Video is not supported yet.')
+        videos = inputs.videos
+        image_num_patches_indices = np.array([0])
+        video_num_patches_indices = np.array([0])
+        video_patch_indices = np.array([0])
+        image_num_patches = []
+        video_num_patches = []
+        image_video_patches = []
+        image_idx_list = []
+        video_idx_list = []
+        image_pixel_values = None
+        video_pixel_values = None
+
        if images:
            # InternS1Processor
+            image_idx_list = findall(input_ids, -100)
            images = make_flat_list_of_images(images)
            image_inputs = self.processor.image_processor(images=images, crop_to_patches=True, return_tensors='pt')
            image_num_patches = image_inputs.pop('num_patches')
-            pixel_values = image_inputs.pop('pixel_values')
+            image_pixel_values = image_inputs.pop('pixel_values')
            image_num_patches_indices = np.cumsum(image_num_patches)
-            # has_video = bool(inputs.videos)  # TODO:video
-        else:
-            pixel_values = None
-            image_num_patches_indices = []
-        assert len(image_num_patches_indices) == len(
-            idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
+        if videos:
+            video_idx_list = findall(input_ids, -200)
+            videos, _ = load_video_hf(videos)
+            videos = make_batched_videos(videos)
+            video_inputs = self.processor.video_processor(videos=videos, return_tensors='pt')
+            video_pixel_values = video_inputs.pop('pixel_values_videos')
+            num_frames_per_video = [len(video) for video in video_pixel_values]
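+            # Unlike images (crop_to_patches=True above), frames are not tiled: each frame counts as one patch.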
+            video_num_patches = [1 for frames in num_frames_per_video for _ in range(frames)]
+            video_patch_indices = np.cumsum(num_frames_per_video)
+            video_num_patches_indices = np.cumsum(video_num_patches)
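+            # Collapse the (videos, frames, ...) dims so frames form one flat batch aligned with the patch indices.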
+            video_pixel_values = video_pixel_values.flatten(0, 1)
+
+        def merge_and_sort(image_idx_list: List[int], video_idx_list: List[int]) -> tuple:
+            """Merge and sort image and video index lists while preserving their relative order."""
+            merged = []
+            is_image_list = []
+            i, j = 0, 0
+
+            while i < len(image_idx_list) and j < len(video_idx_list):
+                if image_idx_list[i] < video_idx_list[j]:
+                    merged.append(image_idx_list[i])
+                    i += 1
+                    is_image_list.append(True)
+                else:
+                    merged.append(video_idx_list[j])
+                    j += 1
+                    is_image_list.append(False)
+            # Add remaining elements
+            merged.extend(image_idx_list[i:])
+            is_image_list.extend([True] * (len(image_idx_list) - i))
+            merged.extend(video_idx_list[j:])
+            is_image_list.extend([False] * (len(video_idx_list) - j))
+            return merged, is_image_list
+
+        # Merge and sort the index lists
+        idx_list, is_image_list = merge_and_sort(image_idx_list, video_idx_list)
+
+        # Validate the lengths
+        if images and len(image_idx_list) > 0:
+            assert len(image_num_patches_indices) == len(image_idx_list)
+        if videos and len(video_idx_list) > 0:
+            assert len(video_patch_indices) == len(video_idx_list)

        def _get_new_tokens(i):
-            start = image_num_patches_indices[i - 1] if i > 0 else 0
-            end = image_num_patches_indices[i]
-            image_seq_length = self.processor.image_seq_length
-            img_tokens: List[int] = self.processor.encode(
-                '<IMG_CONTEXT>', add_special_tokens=False) * image_seq_length * image_num_patches[start:end]
+            if is_image_list[i]:
+                # Find the corresponding image index
+                image_idx = sum(is_image_list[:i])
+                start = image_num_patches_indices[image_idx - 1] if image_idx > 0 else 0
+                end = image_num_patches_indices[image_idx]
+                image_seq_length = self.processor.image_seq_length
+                image_video_patches.append(image_pixel_values[start:end])
+                img_tokens: List[int] = self.processor.encode(
+                    '<IMG_CONTEXT>', add_special_tokens=False) * image_seq_length * image_num_patches[image_idx]
+            else:
+                # Find the corresponding video index
+                video_idx = i - sum(is_image_list[:i])
+                current_patch = video_patch_indices[video_idx - 1] if video_idx > 0 else 0
+                end_patch = video_patch_indices[video_idx]
+
+                start = video_num_patches_indices[current_patch] if video_idx > 0 else 0
+                end = video_num_patches_indices[end_patch - 1]
+                image_video_patches.append(video_pixel_values[start:end])
+                image_seq_length = self.processor.image_seq_length
+                num_patches = list(video_num_patches[current_patch:end_patch])
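+                # Render one "FrameN: <img>...</img>" line per frame, joined by newlines.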
+                video_prompt = '\n'.join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seq_length * num_patches[i]}</img>"
+                    for i in range(len(num_patches)))
+                img_tokens = self.processor.encode(video_prompt, add_special_tokens=False)
            return img_tokens

        encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
            input_ids, labels, loss_scale, idx_list, _get_new_tokens)
-        encoded['pixel_values'] = pixel_values
+        if images or videos:
+            encoded['pixel_values'] = concatenate_list(image_video_patches)
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -247,8 +325,6 @@ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, An
                pixel_values, vision_feature_layer=-1, vision_feature_select_strategy='default')
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        elif is_deepspeed_enabled():
            dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
            vit_embeds = model.model.vision_tower.embeddings(dummy_pixel_values)[0].to(device=device)
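
Aside on the `self._extend_tokens(...)` call in `_encode`: below is a minimal sketch of the placeholder-expansion contract it appears to fulfill, assuming each entry in `idx_list` points at a single placeholder token (`-100` for an image, `-200` for a video) that is replaced by the tokens returned from `_get_new_tokens(i)`. `extend_tokens_sketch` is a hypothetical stand-in, not the actual swift implementation, and it ignores the parallel `labels`/`loss_scale` handling.

```python
from typing import Callable, List


def extend_tokens_sketch(input_ids: List[int], idx_list: List[int],
                         get_new_tokens: Callable[[int], List[int]]) -> List[int]:
    # Assumed contract: replace the single token at each idx_list[i] with get_new_tokens(i).
    new_ids: List[int] = []
    prev = 0
    for i, idx in enumerate(idx_list):
        new_ids += input_ids[prev:idx]  # copy tokens before the placeholder
        new_ids += get_new_tokens(i)    # splice in the expanded image/video tokens
        prev = idx + 1                  # skip the placeholder itself
    new_ids += input_ids[prev:]
    return new_ids


# One image placeholder (-100) followed by one video placeholder (-200):
ids = [1, -100, 2, -200, 3]
out = extend_tokens_sketch(ids, [1, 3], lambda i: [7, 7] if i == 0 else [9])
assert out == [1, 7, 7, 2, 9, 3]
```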