Commit 7cfb5c8

support Intern-S1 video (#5514)
* wip video
* support video
* test
* si video
* test
* fix test input
* annotation
1 parent a52025d commit 7cfb5c8

File tree

swift/llm/template/template/internvl.py
tests/test_align/test_template/test_video.py

2 files changed: +108 -21 lines


swift/llm/template/template/internvl.py

Lines changed: 96 additions & 20 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from ast import Tuple
 from functools import partial
 from typing import Any, Dict, List, Literal, Optional

@@ -183,49 +184,126 @@ class InternS1Template(Internvl2Template, ThinkingTemplate):
         'making your solution path and reasoning clear to others. '
         'Please put your thinking process within <think>...</think> tags.')

+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type in ['image', 'video']
+        if media_type == 'video':
+            if self.mode == 'vllm':
+                return ['<video>']
+            else:
+                return [[-200]]
+        return super().replace_tag(media_type, index, inputs)
+
     def _swift_encode(self, inputs: StdTemplateInputs):
         if inputs.system is None and self.template_meta.response_prefix == '<think>':
             inputs.system = self.InternS1DefaultThinkinngSystem

         return super()._swift_encode(inputs)

     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        from transformers.image_utils import make_flat_list_of_images
+        from transformers.image_utils import make_flat_list_of_images, concatenate_list
+        from transformers.video_utils import make_batched_videos
+        from swift.llm.template.vision_utils import load_video_hf
         import numpy as np
         encoded = super(InternvlTemplate, self)._encode(inputs)
         input_ids = encoded['input_ids']
-        idx_list = findall(input_ids, -100)
         labels = encoded['labels']
         loss_scale = encoded.get('loss_scale', None)
         images = inputs.images
-        if inputs.videos:
-            # TODO
-            raise NotImplementedError('Video is not supported yet.')
+        videos = inputs.videos
+        image_num_patches_indices = np.array([0])
+        video_num_patches_indices = np.array([0])
+        video_patch_indices = np.array([0])
+        image_num_patches = []
+        video_num_patches = []
+        image_video_patches = []
+        image_idx_list = []
+        video_idx_list = []
+        image_pixel_values = None
+        video_pixel_values = None
+
         if images:
             # InternS1Processor
+            image_idx_list = findall(input_ids, -100)
             images = make_flat_list_of_images(images)
             image_inputs = self.processor.image_processor(images=images, crop_to_patches=True, return_tensors='pt')
             image_num_patches = image_inputs.pop('num_patches')
-            pixel_values = image_inputs.pop('pixel_values')
+            image_pixel_values = image_inputs.pop('pixel_values')
             image_num_patches_indices = np.cumsum(image_num_patches)
-            # has_video = bool(inputs.videos)  # TODO:video
-        else:
-            pixel_values = None
-            image_num_patches_indices = []
-        assert len(image_num_patches_indices) == len(
-            idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
+        if videos:
+            video_idx_list = findall(input_ids, -200)
+            videos, _ = load_video_hf(videos)
+            videos = make_batched_videos(videos)
+            video_inputs = self.processor.video_processor(videos=videos, return_tensors='pt')
+            video_pixel_values = video_inputs.pop('pixel_values_videos')
+            num_frames_per_video = [len(video) for video in video_pixel_values]
+            video_num_patches = [1 for frames in num_frames_per_video for _ in range(frames)]
+            video_patch_indices = np.cumsum(num_frames_per_video)
+            video_num_patches_indices = np.cumsum(video_num_patches)
+            video_pixel_values = video_pixel_values.flatten(0, 1)
+
+        def merge_and_sort(image_idx_list: List[int], video_idx_list: List[int]) -> tuple:
+            """Merge and sort image and video index lists while preserving their relative order."""
+            merged = []
+            is_image_list = []
+            i, j = 0, 0
+
+            while i < len(image_idx_list) and j < len(video_idx_list):
+                if image_idx_list[i] < video_idx_list[j]:
+                    merged.append(image_idx_list[i])
+                    i += 1
+                    is_image_list.append(True)
+                else:
+                    merged.append(video_idx_list[j])
+                    j += 1
+                    is_image_list.append(False)
+            # Add remaining elements
+            merged.extend(image_idx_list[i:])
+            is_image_list.extend([True] * (len(image_idx_list) - i))
+            merged.extend(video_idx_list[j:])
+            is_image_list.extend([False] * (len(video_idx_list) - j))
+            return merged, is_image_list
+
+        # Merge and sort the index lists
+        idx_list, is_image_list = merge_and_sort(image_idx_list, video_idx_list)
+
+        # Validate the lengths
+        if images and len(image_idx_list) > 0:
+            assert len(image_num_patches_indices) == len(image_idx_list)
+        if videos and len(video_idx_list) > 0:
+            assert len(video_patch_indices) == len(video_idx_list)

         def _get_new_tokens(i):
-            start = image_num_patches_indices[i - 1] if i > 0 else 0
-            end = image_num_patches_indices[i]
-            image_seq_length = self.processor.image_seq_length
-            img_tokens: List[int] = self.processor.encode(
-                '<IMG_CONTEXT>', add_special_tokens=False) * image_seq_length * image_num_patches[start:end]
+            if is_image_list[i]:
+                # Find the corresponding image index
+                image_idx = sum(is_image_list[:i])
+                start = image_num_patches_indices[image_idx - 1] if image_idx > 0 else 0
+                end = image_num_patches_indices[image_idx]
+                image_seq_length = self.processor.image_seq_length
+                image_video_patches.append(image_pixel_values[start:end])
+                img_tokens: List[int] = self.processor.encode(
+                    '<IMG_CONTEXT>', add_special_tokens=False) * image_seq_length * image_num_patches[image_idx]
+            else:
+                # Find the corresponding video index
+                video_idx = i - sum(is_image_list[:i])
+                current_patch = video_patch_indices[video_idx - 1] if video_idx > 0 else 0
+                end_patch = video_patch_indices[video_idx]
+
+                start = video_num_patches_indices[current_patch] if video_idx > 0 else 0
+                end = video_num_patches_indices[end_patch - 1]
+                image_video_patches.append(video_pixel_values[start:end])
+                image_seq_length = self.processor.image_seq_length
+                num_patches = list(video_num_patches[current_patch:end_patch])
+                video_prompt = '\n'.join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seq_length * num_patches[i]}</img>"
+                    for i in range(len(num_patches)))
+                img_tokens = self.processor.encode(video_prompt, add_special_tokens=False)
             return img_tokens

         encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
             input_ids, labels, loss_scale, idx_list, _get_new_tokens)
-        encoded['pixel_values'] = pixel_values
+        if images or videos:
+            encoded['pixel_values'] = concatenate_list(image_video_patches)
         return encoded

     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -247,8 +325,6 @@ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
                 pixel_values, vision_feature_layer=-1, vision_feature_select_strategy='default')
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
         elif is_deepspeed_enabled():
             dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
             vit_embeds = model.model.vision_tower.embeddings(dummy_pixel_values)[0].to(device=device)
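
For readers tracing the new interleaving logic, here is a standalone sketch of the merge_and_sort helper added above, run on made-up placeholder positions (the index values are purely illustrative and are not taken from a real encoding):

from typing import List


def merge_and_sort(image_idx_list: List[int], video_idx_list: List[int]) -> tuple:
    # Same two-pointer merge as in InternS1Template._encode above: interleave two
    # already-sorted position lists and record which entries came from images.
    merged, is_image_list = [], []
    i, j = 0, 0
    while i < len(image_idx_list) and j < len(video_idx_list):
        if image_idx_list[i] < video_idx_list[j]:
            merged.append(image_idx_list[i])
            i += 1
            is_image_list.append(True)
        else:
            merged.append(video_idx_list[j])
            j += 1
            is_image_list.append(False)
    merged.extend(image_idx_list[i:])
    is_image_list.extend([True] * (len(image_idx_list) - i))
    merged.extend(video_idx_list[j:])
    is_image_list.extend([False] * (len(video_idx_list) - j))
    return merged, is_image_list


# Hypothetical token positions: image placeholders (-100) found at 3 and 40,
# a video placeholder (-200) found at 12.
merged, is_image = merge_and_sort([3, 40], [12])
assert merged == [3, 12, 40]
assert is_image == [True, False, True]

_get_new_tokens then walks idx_list together with is_image_list, expanding each placeholder with either image patch tokens or a per-frame video prompt, while the corresponding pixel patches are appended to image_video_patches in the same order.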

tests/test_align/test_template/test_video.py

Lines changed: 12 additions & 1 deletion
@@ -190,6 +190,16 @@ def test_ovis2_5():
     print(f'response: {response}')


+def test_interns1():
+    pt_engine = PtEngine('Shanghai_AI_Laboratory/Intern-S1-mini')
+    messages = [{'role': 'user', 'content': '<video>Describe this video in detail.'}]
+    videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
+    response = _infer_model(pt_engine, messages=messages, videos=videos)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, videos=videos)
+    assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -207,4 +217,5 @@ def test_ovis2_5():
     # test_glm4_1v()  # bug now, wait model fix
     # test_keye_vl()
     # test_glm4_5v()
-    test_ovis2_5()
+    # test_ovis2_5()
+    test_interns1()
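
The new test drives video inference through the shared _infer_model helper. For orientation, a roughly equivalent standalone call through ms-swift's engine interface might look like the sketch below; InferRequest and the max_tokens value are assumptions for illustration and are not part of this diff:

from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('Shanghai_AI_Laboratory/Intern-S1-mini')
request = InferRequest(
    messages=[{'role': 'user', 'content': '<video>Describe this video in detail.'}],
    videos=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4'])
# max_tokens is an illustrative value, not taken from the test above.
responses = engine.infer([request], RequestConfig(max_tokens=512))
print(responses[0].choices[0].message.content)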
