Commit 48890b0

Merge pull request #183 from LLaVA-VL/yhzhang/video_dev
update video code
2 parents 6ca43a6 + 066ea45

5 files changed: +164 -39 lines changed

llava/model/llava_arch.py

Lines changed: 98 additions & 16 deletions
@@ -93,6 +93,13 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         self.config.mm_vision_select_feature = mm_vision_select_feature
         self.config.mm_patch_merge_type = mm_patch_merge_type
 
+        if not hasattr(self.config, 'add_faster_video'):
+            if model_args.add_faster_video:
+                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+                self.faster_token = nn.Parameter(
+                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+                )
+
         if getattr(self, "mm_projector", None) is None:
             self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
 
@@ -160,19 +167,19 @@ def get_model(self):
     def get_vision_tower(self):
         return self.get_model().get_vision_tower()
 
-    def get_2dPool(self, image_feature):
+    def get_2dPool(self, image_feature, stride=2):
         height = width = self.get_vision_tower().num_patches_per_side
         num_frames, num_tokens, num_dim = image_feature.shape
         image_feature = image_feature.view(num_frames, height, width, -1)
         image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
         # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
         if self.config.mm_spatial_pool_mode == "average":
-            image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+            image_feature = nn.functional.avg_pool2d(image_feature, stride)
         elif self.config.mm_spatial_pool_mode == "max":
-            image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+            image_feature = nn.functional.max_pool2d(image_feature, stride)
         elif self.config.mm_spatial_pool_mode == "bilinear":
             height, weight = image_feature.shape[2:]
-            scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)]
+            scaled_shape = [math.ceil(height / stride), math.ceil(weight / stride)]
             image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
 
         else:
@@ -191,13 +198,54 @@ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
         videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
         per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
         all_videos_or_images_features = []
+        all_faster_video_features = []
+        cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride
 
         for idx, feat in enumerate(per_videos_or_images_features):
+
             feat = self.get_model().mm_projector(feat)
-            if idx in video_idx_in_batch:
-                feat = self.get_2dPool(feat)
-            all_videos_or_images_features.append(feat)
-        return all_videos_or_images_features
+            faster_video_feature = 0
+            slower_img_feat = 0
+            if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
+                slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
+                if self.config.add_faster_video:
+                    cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2
+                    faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
+            if slower_img_feat is not 0:
+                all_videos_or_images_features.append(slower_img_feat)
+            else:
+                all_videos_or_images_features.append(feat)
+            all_faster_video_features.append(faster_video_feature)
+        return all_videos_or_images_features,all_faster_video_features
+
+    def add_token_per_grid(self, image_feature):
+        resize_h = int(math.sqrt(image_feature.shape[1]))
+        num_frames = image_feature.shape[0]
+        feature_dim = image_feature.shape[-1]
+
+        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
+        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+        if self.config.add_faster_video:
+            # import pdb; pdb.set_trace()
+            # (3584, 832, 14) -> (3584, 64, 13, 14)
+            image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1)
+            # (3584, 64, 13, 14) -> (64, 13, 14, 3584)
+            image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
+            # (64, 13, 14, 3584) -> (64, 13*14, 3584)
+            image_feature = image_feature.flatten(1, 2)
+            # import pdb; pdb.set_trace()
+            return image_feature
+        # import pdb; pdb.set_trace()
+        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+        return image_feature
+
+    def add_token_per_frame(self, image_feature):
+        image_feature = image_feature.permute(2, 0, 1).contiguous()
+        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+        image_feature = image_feature.permute(1, 2, 0).contiguous()
+        return image_feature
 
     def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
         vision_tower = self.get_vision_tower()
@@ -224,6 +272,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
             concat_images = torch.cat([image for image in images_list], dim=0)
             split_sizes = [image.shape[0] for image in images_list]
             encoded_image_features = self.encode_images(concat_images)
+            # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
 
             # This is a list, each element is [num_images, patch * patch, dim]
             # rank_print(f"Concat images : {concat_images.shape}")
@@ -239,6 +288,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
             # image_features = torch.split(image_features, split_sizes, dim=0)
             mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
             image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+            mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")
 
             if mm_patch_merge_type == "flat":
                 image_features = [x.flatten(0, 1) for x in image_features]
@@ -253,13 +303,44 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
                     # rank0_print("At least we are reaching here")
                     if image_idx in video_idx_in_batch: # video operations
                         # rank0_print("Video")
-                        if "unpad" in mm_patch_merge_type:
-                            # image_feature = image_feature.permute(2, 0, 1).contiguous()
-                            # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
-                            # image_feature = image_feature.permute(1, 2, 0).contiguous()
+                        if mm_newline_position == "grid":
+                            # Grid-wise
+                            image_feature = self.add_token_per_grid(image_feature)
+                            if self.config.add_faster_video:
+                                faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
+                                # Add a token for each frame
+                                concat_slow_fater_token = []
+                                # import pdb; pdb.set_trace()
+                                for _ in range(image_feature.shape[0]):
+                                    if _ % self.config.faster_token_stride == 0:
+                                        concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
+                                    else:
+                                        concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
+                                # import pdb; pdb.set_trace()
+                                image_feature = torch.cat(concat_slow_fater_token)
+
+                            # print("!!!!!!!!!!!!")
+
+                            new_image_features.append(image_feature)
+                        elif mm_newline_position == "frame":
+                            # Frame-wise
+                            image_feature = self.add_token_per_frame(image_feature)
+
+                            new_image_features.append(image_feature.flatten(0, 1))
+
+                        elif mm_newline_position == "one_token":
+                            # one-token
                             image_feature = image_feature.flatten(0, 1)
-                            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
-
+                            if 'unpad' in mm_patch_merge_type:
+                                image_feature = torch.cat((
+                                    image_feature,
+                                    self.model.image_newline[None].to(image_feature.device)
+                                ), dim=0)
+                            new_image_features.append(image_feature)
+                        elif mm_newline_position == "no_token":
+                            new_image_features.append(image_feature.flatten(0, 1))
+                        else:
+                            raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
                 elif image_feature.shape[0] > 1: # multi patches and multi images operations
                     # rank0_print("Single-images")
                     base_image_feature = image_feature[0]
@@ -316,12 +397,13 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
                         pass
                     else:
                         image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    new_image_features.append(image_feature)
                 else: # single image operations
                     image_feature = image_feature[0]
                     if "unpad" in mm_patch_merge_type:
                         image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
 
-                new_image_features.append(image_feature)
+                    new_image_features.append(image_feature)
             image_features = new_image_features
         else:
             raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
@@ -506,4 +588,4 @@ def initialize_vision_tokenizer(self, model_args, tokenizer):
             for p in self.get_input_embeddings().parameters():
                 p.requires_grad = False
             for p in self.get_output_embeddings().parameters():
-                p.requires_grad = False
+                p.requires_grad = False
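The core of this diff is two-rate ("slow"/"fast") pooling of video tokens: get_2dPool now takes an explicit stride, encode_multimodals pools each video's frame features once at mm_spatial_pool_stride and, when add_faster_video is set, again at double that stride, and the learnable faster_token later marks where the coarser stream is interleaved. Below is a minimal standalone sketch of just the pooling step; the pool_frame_tokens helper, shapes, and values are illustrative assumptions, not code from the repo.

import math
import torch
import torch.nn.functional as F

def pool_frame_tokens(feat, stride, mode="average"):
    # feat: (num_frames, num_tokens, dim), with num_tokens forming a square grid
    num_frames, num_tokens, dim = feat.shape
    side = int(math.sqrt(num_tokens))
    x = feat.view(num_frames, side, side, dim).permute(0, 3, 1, 2)  # (F, C, H, W)
    if mode == "average":
        x = F.avg_pool2d(x, stride)
    elif mode == "max":
        x = F.max_pool2d(x, stride)
    else:  # "bilinear", mirroring the interpolate branch of get_2dPool
        scaled = [math.ceil(side / stride), math.ceil(side / stride)]
        x = F.interpolate(x, size=scaled, mode="bilinear")
    return x.flatten(2).permute(0, 2, 1)  # (F, pooled_tokens, dim)

# Toy shapes: 8 frames, a 24x24 token grid, 64-dim features
frames = torch.randn(8, 24 * 24, 64)
slow = pool_frame_tokens(frames, stride=2)  # (8, 144, 64): the "slow" stream
fast = pool_frame_tokens(frames, stride=4)  # (8, 36, 64): the coarser "fast" stream
print(slow.shape, fast.shape)

In encode_multimodals the slow features replace the full-resolution tokens for video entries, while the fast features are collected separately and only spliced in (every faster_token_stride frames) by the "grid" branch of prepare_inputs_labels_for_multimodal.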

llava/train/train.py

Lines changed: 30 additions & 3 deletions
@@ -108,7 +108,11 @@ class ModelArguments:
     pos_skipping_range: Optional[int] = field(default=4096)
 
 
-    mm_newline_position: Optional[str] = field(default="one_token")
+    mm_newline_position: Optional[str] = field(default="grid")
+    delay_load: Optional[bool] = field(default=True)
+    add_faster_video: Optional[bool] = field(default=False)
+    faster_token_stride: Optional[int] = field(default=10)
+
 
 
 @dataclass
@@ -126,6 +130,8 @@ class DataArguments:
     video_folder: Optional[str] = field(default=None)
     video_fps: Optional[int] = field(default=1)
     frames_upbound: Optional[int] = field(default=0)
+    add_time_instruction: Optional[bool] = field(default=False)
+    force_sample: Optional[bool] = field(default=False)
 
 
 @dataclass
@@ -1158,10 +1164,22 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
                     frame_files.sort() # Ensure the frames are sorted if they are named sequentially
 
                     # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames
-                    num_frames_to_sample = 10
+                    if self.data_args.force_sample:
+                        num_frames_to_sample = self.data_args.frames_upbound
+                    else:
+                        num_frames_to_sample = 10
+
+                    avg_fps = 2
+
                     total_frames = len(frame_files)
                     sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
 
+
+                    frame_time = [i/2 for i in sampled_indices]
+                    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+
+                    video_time = total_frames / avg_fps
+
                     # Read and store the sampled frames
                     video = []
                     for idx in sampled_indices:
@@ -1173,12 +1191,16 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
                         except IOError:
                             print(f"Failed to read frame at path: {frame_path}")
                 else:
-                    video = process_video_with_decord(video_file, self.data_args)
+                    video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(video_file, self.data_args)
 
                 processor = self.data_args.image_processor
                 image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
+                if self.data_args.add_time_instruction:
+                    time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
+                    sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruciton}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}'
                 image = [(image, video[0].size, "video")]
                 sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
+                # print(sources)
             except Exception as e:
                 print(f"Error: {e}")
                 print(f"Failed to read video file: {video_file}")
@@ -1580,6 +1602,11 @@ def make_inputs_require_grad(module, input, output):
         model.config.tokenizer_padding_side = tokenizer.padding_side
         model.config.tokenizer_model_max_length = tokenizer.model_max_length
         model.config.mm_newline_position = model_args.mm_newline_position
+        model.config.add_faster_video = model_args.add_faster_video
+        model.config.faster_token_stride = model_args.faster_token_stride
+        model.config.add_time_instruction = data_args.add_time_instruction
+        model.config.force_sample = data_args.force_sample
+        model.config.mm_spatial_pool_stride = model_args.mm_spatial_pool_stride
 
         ### Deciding train which part of the model
         if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train
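With add_time_instruction enabled, _get_item rewrites the first user turn so the model is told the video length and the sampled timestamps before the question. A small sketch of that string manipulation with made-up values; DEFAULT_IMAGE_TOKEN stands in for the constant train.py imports.

DEFAULT_IMAGE_TOKEN = "<image>"  # stand-in for the repo's image placeholder token

# Made-up values of the kind returned by process_video_with_decord
video_time = 32.0                               # seconds
sampled_seconds = [0.0, 8.0, 16.0, 24.0, 31.9]  # one entry per sampled frame
num_frames_to_sample = len(sampled_seconds)
frame_time = ",".join(f"{t:.2f}s" for t in sampled_seconds)

time_instruction = (
    f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are "
    f"uniformly sampled from it. These frames are located at {frame_time}."
    "Please answer the following questions related to this video."
)

# The first conversation turn, as it might appear in a training sample
question = f"{DEFAULT_IMAGE_TOKEN}\nWhat happens in this video?"
# Mirror the rewrite in _get_item: strip the image token from the question,
# then prepend the token and the time instruction
rewritten = f"{DEFAULT_IMAGE_TOKEN}\n{time_instruction}\n{question.replace(DEFAULT_IMAGE_TOKEN, '')}"
print(rewritten)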

llava/utils.py

Lines changed: 9 additions & 2 deletions
@@ -25,18 +25,25 @@
 def process_video_with_decord(video_file, data_args):
     vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
     total_frame_num = len(vr)
+    video_time = total_frame_num / vr.get_avg_fps()
     avg_fps = round(vr.get_avg_fps() / data_args.video_fps)
     frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
+    frame_time = [i/avg_fps for i in frame_idx]
+
 
     if data_args.frames_upbound > 0:
-        if len(frame_idx) > data_args.frames_upbound:
+        if len(frame_idx) > data_args.frames_upbound or data_args.force_sample:
             uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
             frame_idx = uniform_sampled_frames.tolist()
+            frame_time = [i/vr.get_avg_fps() for i in frame_idx]
 
     video = vr.get_batch(frame_idx).asnumpy()
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+
+    num_frames_to_sample = num_frames = len(frame_idx)
     # https://github.com/dmlc/decord/issues/208
     vr.seek(0)
-    return video
+    return video, video_time, frame_time, num_frames_to_sample
 
 def process_video_with_pyav(video_file, data_args):
     container = av.open(video_file)
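Because process_video_with_decord now returns four values instead of one, callers written against the old single-return signature must unpack the tuple, as train.py's _get_item does above. A caller-side sketch, assuming decord and numpy are installed and that data_args carries the fields the function reads (video_fps, frames_upbound, force_sample); the file path is a placeholder.

from types import SimpleNamespace

from llava.utils import process_video_with_decord  # returns 4 values after this commit

# Minimal stand-in for DataArguments with the fields the function reads
data_args = SimpleNamespace(video_fps=1, frames_upbound=32, force_sample=True)

# "example.mp4" is a placeholder; any file decord can open works
video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(
    "example.mp4", data_args
)

print(video.shape)  # (num_frames_to_sample, H, W, 3) array returned by decord
print(f"{video_time:.2f}s total, {num_frames_to_sample} frames at {frame_time}")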
