
Commit e96b45b

update video code

Parent: a78b59d
2 files changed, 7 insertions(+), 1 deletion(-)

llava/model/builder.py (5 additions, 0 deletions)
@@ -238,6 +238,11 @@ def load_from_hf(repo_id, filename, subfolder=None):
                 llava_cfg.delay_load = True  # a workaround for correctly loading v1.5 models
             else:
                 llava_cfg = customized_config
+
+            if overwrite_config is not None:
+                rank0_print(f"Overwriting config with {overwrite_config}")
+                for k, v in overwrite_config.items():
+                    setattr(llava_cfg, k, v)
             model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
         except:
             raise ValueError(f"Model {model_name} not supported")
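For context, a minimal sketch of how the new overwrite_config hook could be driven from the caller's side. The model path, model name, and config key below are illustrative placeholders, not values from this commit, and the sketch assumes overwrite_config is exposed as a keyword argument of load_pretrained_model:

from llava.model.builder import load_pretrained_model

# Illustrative key only (assumption): any attribute settable on llava_cfg works here.
overwrite_config = {"mm_spatial_pool_stride": 2}

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="path/to/llava-checkpoint",  # placeholder path
    model_base=None,
    model_name="llava_example",             # placeholder name
    overwrite_config=overwrite_config,      # each (k, v) is setattr'd onto llava_cfg before from_pretrained
)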

llava/model/llava_arch.py (2 additions, 1 deletion)
@@ -223,7 +223,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio

         if isinstance(modalities, str):
             modalities = [modalities]
-
+
         if type(images) is list or images.ndim == 5:
             if type(images) is list:
                 images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
@@ -242,6 +242,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
             else:
                 images_list.append(image.unsqueeze(0))

+            # import pdb;pdb.set_trace()
             concat_images = torch.cat([image for image in images_list], dim=0)
             split_sizes = [image.shape[0] for image in images_list]
             encoded_image_features = self.encode_images(concat_images)
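The code around this hunk normalizes a mixed batch of single images (3-D tensors) and video clips (4-D frame stacks) to a common 4-D layout, encodes everything in one forward pass, and uses split_sizes to recover per-sample features. A self-contained sketch of that batching trick, with a dummy encoder standing in for the real self.encode_images (an assumption, not the actual vision tower):

import torch

def encode_images(pixels):
    # Stand-in encoder (assumption): maps (N, C, H, W) pixels to (N, D) features.
    return pixels.flatten(1).mean(dim=1, keepdim=True)

single_image = torch.randn(3, 336, 336)      # one image: (C, H, W)
video_clip = torch.randn(8, 3, 336, 336)     # one 8-frame clip: (T, C, H, W)
images = [single_image, video_clip]

# As in the diff: lift 3-D images to 4-D so images and clips share a layout.
images_list = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

concat_images = torch.cat(images_list, dim=0)          # (1 + 8, C, H, W)
split_sizes = [img.shape[0] for img in images_list]    # [1, 8]
encoded = encode_images(concat_images)                 # one encoder pass: (9, D)
per_sample = torch.split(encoded, split_sizes, dim=0)  # [(1, D), (8, D)]
print([t.shape for t in per_sample])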
