Commit d7f023a

Internvl2 support video (#1366)
1 parent badcaf1 commit d7f023a

2 files changed: 105 additions & 0 deletions

swift/llm/utils/template.py

Lines changed: 70 additions & 0 deletions
@@ -1387,6 +1387,76 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
 
 
 class Internvl2Template(InternvlTemplate):
+    video_segments = 8
+
+    def replace_tag(self, media_type, index, example) -> List[Context]:
+        if media_type == 'image':
+            return [[-100]]
+        elif media_type == 'video':
+            context_list = []
+            for i in range(self.video_segments):
+                context_list.append(f'Frame{i + 1}: ')
+                context_list.append([-100])
+                context_list.append('\n')
+            return context_list
+
+    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        inputs, _ = super(InternvlTemplate, self).encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        input_ids = inputs['input_ids']
+        idx_list = _findall(input_ids, -100)
+        labels = inputs.get('labels')
+        images_path = example.get('images') or []
+        videos_path = example.get('videos') or []
+        if images_path:
+            from .vision_utils import load_image
+            pixel_values = []
+            if isinstance(images_path, str):
+                images_path = [images_path]
+            for image_path in images_path:
+                pixel_values.append(load_image(image_path))
+
+            assert len(images_path) == len(idx_list)
+            added_tokens_len = 0
+            patches = 0
+            for idx, pv in zip(idx_list, pixel_values):
+                patches += pv.shape[0]
+                img_tokens: List[int] = self.tokenizer.encode(
+                    '<img>' + '<IMG_CONTEXT>' * self.num_image_token * pv.shape[0] + '</img>\n',
+                    add_special_tokens=False)
+                input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
+                if labels is not None:
+                    labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) + labels[idx + added_tokens_len
+                                                                                                 + 1:]
+                added_tokens_len += len(img_tokens) - 1
+            inputs['input_ids'] = input_ids
+            inputs['labels'] = labels
+            inputs['pixel_values'] = torch.cat(pixel_values).to(self.model.dtype)
+            inputs['image_flags'] = torch.ones(patches)
+        if videos_path:
+            if not isinstance(videos_path, (list, tuple)):
+                videos_path = [videos_path]
+            assert len(videos_path) == 1
+            from swift.llm.utils.vision_utils import load_video
+            pixel_values, num_patches = load_video(videos_path[0], num_segments=self.video_segments)
+            assert len(num_patches) == len(idx_list)
+            added_tokens_len = 0
+            for idx, num_patch in zip(idx_list, num_patches):
+                img_tokens: List[int] = self.tokenizer.encode(
+                    '<img>' + '<IMG_CONTEXT>' * self.num_image_token * num_patch + '</img>\n', add_special_tokens=False)
+                input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
+                if labels is not None:
+                    labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) + labels[idx + added_tokens_len
+                                                                                                 + 1:]
+                added_tokens_len += len(img_tokens) - 1
+            inputs['input_ids'] = input_ids
+            inputs['labels'] = labels
+            inputs['pixel_values'] = pixel_values.to(self.model.dtype)
+            inputs['image_flags'] = torch.ones(sum(num_patches))
+        inputs.pop('loss_scale', None)
+        return inputs, {}
+
     def __init__(self):
         self.system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
         Template.__init__(self, [], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'],
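
Note on replace_tag in the hunk above: for a video it emits, per sampled segment, a literal 'Frame{i}: ' label, a [-100] placeholder (a single pseudo-token that encode() later replaces with real image tokens), and a newline. A minimal standalone sketch of the list it builds; the print is illustrative only:

    # What replace_tag produces for media_type == 'video' with video_segments = 8.
    video_segments = 8
    context_list = []
    for i in range(video_segments):
        context_list.append(f'Frame{i + 1}: ')  # literal prompt text
        context_list.append([-100])             # image placeholder, one per frame
        context_list.append('\n')
    print(context_list[:6])
    # ['Frame1: ', [-100], '\n', 'Frame2: ', [-100], '\n']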

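Note on the splicing loops in encode: each [-100] placeholder occupies exactly one slot in input_ids and is replaced by len(img_tokens) tokens, so every placeholder after it shifts right by len(img_tokens) - 1; that is the offset added_tokens_len accumulates. A self-contained sketch of just that bookkeeping, with made-up token ids and no tokenizer involved:

    from typing import List

    def expand_placeholders(input_ids: List[int], expansions: List[List[int]]) -> List[int]:
        # Placeholder positions, found before any expansion happens.
        idx_list = [i for i, t in enumerate(input_ids) if t == -100]
        assert len(idx_list) == len(expansions)
        added_tokens_len = 0
        for idx, img_tokens in zip(idx_list, expansions):
            pos = idx + added_tokens_len  # original index plus cumulative shift
            input_ids = input_ids[:pos] + img_tokens + input_ids[pos + 1:]
            added_tokens_len += len(img_tokens) - 1
        return input_ids

    # Two placeholders, expanded to 3 and 2 tokens respectively:
    print(expand_placeholders([1, -100, 2, -100, 3], [[10, 11, 12], [20, 21]]))
    # [1, 10, 11, 12, 2, 20, 21, 3]
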
swift/llm/utils/vision_utils.py

Lines changed: 35 additions & 0 deletions
@@ -3,6 +3,7 @@
 import os
 from io import BytesIO
 
+import numpy as np
 import requests
 import torch
 import torchvision.transforms as T
@@ -97,3 +98,37 @@ def load_image(img_path, input_size=448, max_num=6):
     pixel_values = [transform(image) for image in images]
     pixel_values = torch.stack(pixel_values)
     return pixel_values
+
+
+def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array(
+        [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
+    return frame_indices
+
+
+def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    from decord import VideoReader, cpu
+    from PIL import Image
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = build_transform(input_size=input_size)
+    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
+        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
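
Note on get_index above: with no bound it samples the midpoint of each of num_segments equal-width bins over the whole clip, i.e. segment i lands near start_idx + seg_size / 2 + i * seg_size. A quick check on a hypothetical clip (the 10 s, 30 fps parameters are made up; assumes this commit's vision_utils is importable):

    from swift.llm.utils.vision_utils import get_index

    # 10 s at 30 fps -> frames 0..299, split into 8 bins of 37.375 frames each.
    indices = get_index(bound=None, fps=30.0, max_frame=299, num_segments=8)
    print(indices)
    # [ 18  55  93 130 168 205 242 280] -- one frame near each bin midpoint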

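How the two files fit together, as a hedged usage sketch (the video path is hypothetical and decord must be installed; only the import path and the num_segments keyword come from this commit):

    from swift.llm.utils.vision_utils import load_video

    pixel_values, num_patches_list = load_video('demo.mp4', num_segments=8)
    # pixel_values: (sum(num_patches_list), 3, 448, 448) tensor of frame tiles;
    # with max_num=1, dynamic_preprocess typically yields one tile per frame.
    print(pixel_values.shape, num_patches_list)

    # Internvl2Template.encode then wraps each frame in a span of this shape,
    # sized by num_image_token * num_patch:
    # 'Frame1: <img><IMG_CONTEXT>...</img>\nFrame2: <img><IMG_CONTEXT>...</img>\n...'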