Skip to content

Commit 5278e94

Browse files
committed
updates
1 parent afe35d4 commit 5278e94

File tree

4 files changed

+11
-2
lines changed

4 files changed

+11
-2
lines changed

action/dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,8 @@ def get_args_parser():
630630
## avaion refinement
631631
parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
632632
parser.add_argument('--topk_predictions', default = 5, type =int)
633-
633+
parser.add_argument('--llava_checkpoint', default = None, type = str)
634+
634635
return parser
635636

636637
def prepare_llava(pretrained):
@@ -639,7 +640,6 @@ def prepare_llava(pretrained):
639640
from llava.model.builder import load_pretrained_model
640641
warnings.filterwarnings("ignore")
641642
# Load the OneVision model
642-
#pretrained = f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov"
643643
model_name = "llava_qwen"
644644

645645
device_map = "auto"
@@ -722,6 +722,9 @@ def get_topk_predictions(data, idx, k):
722722

723723
pretrained = f"lmms-lab/llava-onevision-qwen2-{args.llm_size}-ov"
724724

725+
if args.llava_checkpoint:
726+
pretrained = args.llava_checkpoint
727+
725728
tokenizer, model, image_processor, max_length = prepare_llava(pretrained)
726729

727730
if args.action_predictions:

llava/model/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
200200
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
201201

202202
elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
203+
203204
tokenizer = AutoTokenizer.from_pretrained(model_path)
204205
if "moe" in model_name.lower() or "A14B" in model_name.lower():
205206
from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
@@ -214,6 +215,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
214215

215216
else:
216217
from llava.model.language_model.llava_qwen import LlavaQwenConfig
218+
217219
if overwrite_config is not None:
218220
llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
219221
rank0_print(f"Overwriting config with {overwrite_config}")

llava/model/language_model/llava_qwen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
4848

4949
def __init__(self, config):
5050
# super(Qwen2ForCausalLM, self).__init__(config)
51+
print ('what does config look like')
52+
print (config)
5153
Qwen2ForCausalLM.__init__(self, config)
54+
5255
config.model_type = "llava_qwen"
5356
config.rope_scaling = None
5457

run_EK100.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ python3 action/dataset.py \
44
--val-metadata /media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv \
55
--llm_size 7b \
66
--llava_num_frames 16 > kitchen_test.out 2>&1 \
7+
# --llava_checkpoint /data/epic_kitchen/EK100_test/checkpoint-8402
78
# --action_predictions action/avaion_predictions.json \
89
# --topk_predictions 10

0 commit comments

Comments
 (0)