Commit c121c20

update video code
1 parent 87f4b6c

6 files changed: +403 -7 lines

.gitignore

Lines changed: 1 addition & 1 deletion

```diff
@@ -70,4 +70,4 @@ build/
 playground/*.json
 mlx_configs/
 data_processing/
-demo/
+# demo/
```

docs/LLaVA-NeXT-Video_0716.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -14,7 +14,7 @@ The new model achieves the best open-source performance in several video benchmarks
 - **Model Card**: [LLaVA-NeXT-Video-32B-Qwen on Hugging Face](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-32B-Qwen)
 - **Inference Script**:
 ```bash
-bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-32B-Qwen qwen_1_5 32 2 average after grid True playground/demo/xU25MMA2N4aVtYay.mp4
+bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-32B-Qwen qwen_1_5 32 2 average grid True playground/demo/xU25MMA2N4aVtYay.mp4
 ```

 ### Evaluation Results
````
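
The dropped `after` argument leaves `grid` in the position that presumably feeds the new `mm_newline_position` setting introduced in `llava/model/llava_arch.py` below. A tiny, hypothetical guard for such a value, using only the set of positions the updated code accepts (this helper is not part of the repo's script):

```python
# Hypothetical validation of a newline-placement argument before it reaches the model config.
VALID_NEWLINE_POSITIONS = {"grid", "frame", "one_token", "no_token"}

def check_newline_position(value: str) -> str:
    if value not in VALID_NEWLINE_POSITIONS:
        raise ValueError(f"Unexpected mm_newline_position: {value}")
    return value

check_newline_position("grid")  # the value used by the updated demo command
```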

llava/model/llava_arch.py

Lines changed: 40 additions & 5 deletions

```diff
@@ -199,6 +199,22 @@ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
             all_videos_or_images_features.append(feat)
         return all_videos_or_images_features
 
+    def add_token_per_grid(self, image_feature):
+        resize_h = int(math.sqrt(image_feature.shape[1]))
+        num_frames = image_feature.shape[0]
+        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
+        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+        return image_feature
+
+    def add_token_per_frame(self, image_feature):
+        image_feature = image_feature.permute(2, 0, 1).contiguous()
+        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+        image_feature = image_feature.permute(1, 2, 0).contiguous()
+        return image_feature
+
     def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
         vision_tower = self.get_vision_tower()
         # rank_print(modalities)
```
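
A self-contained sketch of the tensor reshaping these two new helpers perform, run on dummy inputs. The frame count, patch count, and hidden size are illustrative placeholders, and `image_newline` stands in for `self.model.image_newline`:

```python
import math
import torch

# Illustrative sizes (hypothetical): 4 frames, 576 patches per frame, hidden size 8.
num_frames, num_patches, hidden = 4, 576, 8
image_feature = torch.randn(num_frames, num_patches, hidden)   # (T, P, C), P = H*W
image_newline = torch.randn(hidden)                             # stand-in for model.image_newline
resize_h = int(math.sqrt(num_patches))                          # 24, mirroring add_token_per_grid

# add_token_per_grid: lay the T frames out as a (T*H) x W grid of patch features,
# append one newline feature per grid row, then flatten to a token sequence.
x = image_feature.view(num_frames, 1, resize_h, resize_h, -1)   # (T, 1, H, W, C)
x = x.permute(4, 0, 2, 1, 3).contiguous()                       # (C, T, H, 1, W)
x = x.flatten(1, 2).flatten(2, 3)                                # (C, T*H, W)
x = torch.cat((x, image_newline[:, None, None].expand(*x.shape[:-1], 1)), dim=-1)  # (C, T*H, W+1)
grid_tokens = x.flatten(1, 2).transpose(0, 1)                    # (T*H*(W+1), C)
assert grid_tokens.shape == (num_frames * resize_h * (resize_h + 1), hidden)

# add_token_per_frame: keep each frame's patches flat and append one newline per frame.
y = image_feature.permute(2, 0, 1).contiguous()                  # (C, T, P)
y = torch.cat((y, image_newline[:, None, None].expand(*y.shape[:-1], 1)), dim=-1)  # (C, T, P+1)
frame_tokens = y.permute(1, 2, 0).contiguous()                   # (T, P+1, C)
assert frame_tokens.shape == (num_frames, num_patches + 1, hidden)
```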

```diff
@@ -253,12 +269,31 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
                 # rank0_print("At least we are reaching here")
                 if image_idx in video_idx_in_batch:  # video operations
                     # rank0_print("Video")
-                    if "unpad" in mm_patch_merge_type:
-                        # image_feature = image_feature.permute(2, 0, 1).contiguous()
-                        # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
-                        # image_feature = image_feature.permute(1, 2, 0).contiguous()
+                    if self.config.mm_newline_position == "grid":
+                        # Grid-wise
+                        image_feature = self.add_token_per_grid(image_feature)
+
+                        new_image_features.append(image_feature)
+                    elif self.config.mm_newline_position == "frame":
+                        # Frame-wise
+                        image_feature = self.add_token_per_frame(image_feature)
+
+                        new_image_features.append(image_feature.flatten(0, 1))
+
+                    elif self.config.mm_newline_position == "one_token":
+                        # one-token
                         image_feature = image_feature.flatten(0, 1)
-                        image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
+                        if 'unpad' in mm_patch_merge_type:
+                            image_feature = torch.cat((
+                                image_feature,
+                                self.model.image_newline[None].to(image_feature.device)
+                            ), dim=0)
+                        new_image_features.append(image_feature)
+                    elif self.config.mm_newline_position == "no_token":
+                        new_image_features.append(image_feature.flatten(0, 1))
+                    else:
+                        raise ValueError(f"Unexpected mm_newline_position: {self.config.mm_newline_position}")
+
 
                 elif image_feature.shape[0] > 1:  # multi patches and multi images operations
                     # rank0_print("Single-images")
```

llava/train/train.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -108,6 +108,9 @@ class ModelArguments:
     pos_skipping_range: Optional[int] = field(default=4096)
 
 
+    mm_newline_position: Optional[str] = field(default="one_token")
+
+
 @dataclass
 class DataArguments:
     data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"})
@@ -1576,6 +1579,7 @@ def make_inputs_require_grad(module, input, output):
         model.config.image_split_resolution = data_args.image_split_resolution
         model.config.tokenizer_padding_side = tokenizer.padding_side
         model.config.tokenizer_model_max_length = tokenizer.model_max_length
+        model.config.mm_newline_position = model_args.mm_newline_position
 
         ### Deciding train which part of the model
         if model_args.mm_tunable_parts is None:  # traditional way of deciding which part to train
```
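
Assuming `ModelArguments` is parsed with `transformers.HfArgumentParser` (the usual pattern for such dataclasses in this codebase), the new field becomes a `--mm_newline_position` command-line option whose value is later copied onto `model.config`, as the second hunk shows. A minimal sketch under that assumption; the argv below is hypothetical and only the new field is reproduced:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    # Only the new field is reproduced here; train.py defines many more.
    mm_newline_position: Optional[str] = field(default="one_token")

# e.g. a training launch passing "--mm_newline_position grid" (hypothetical argv):
(model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses(
    args=["--mm_newline_position", "grid"]
)
print(model_args.mm_newline_position)  # "grid", later written to model.config.mm_newline_position
```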
