Skip to content

Commit f3b6f6b

Browse files
author
Ye Shaokai
committed
added assertion to check whether avion gt matches my gt
1 parent 11e0e64 commit f3b6f6b

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

action/ek_eval.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,18 @@ def prepare_llava(pretrained):
337337
model_name = "llava_qwen"
338338

339339
device_map = "auto"
340-
print ('pretrained???', pretrained)
341-
#tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
342-
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map) # Add any other thing you want to pass in llava_model_args
340+
341+
overwrite_config = None
342+
if 'video' in pretrained:
343+
overwrite_config = {'tie_word_embeddings': False, 'use_cache': True, "vocab_size": 152064}
344+
345+
print ('overwrite config', overwrite_config)
346+
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained,
347+
None,
348+
model_name,
349+
torch_dtype="bfloat16",
350+
device_map=device_map,
351+
overwrite_config = overwrite_config) # Add any other thing you want to pass in llava_model_args
343352

344353

345354
return tokenizer, model, image_processor, max_length

action/generate_description.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_p
4747
elif gen_type == "avion_mc":
4848
vn_str = f'{row[10]}:{row[12]}'
4949
avion_preds = avion_train_predictions[str(idx)]['predictions']
50+
gt_from_avion = avion_train_predictions[str(idx)]['target']
5051
mc_data = mc_generator.generate_multi_choice(vn_str, avion_preds, n_options)
5152
options = mc_data['options'][0]
5253
gt_answer_letter = mc_data['gt_answer_letter'][0]
5354
gt_answer_name = mc_data['gt_answer_name'][0]
55+
assert gt_answer_name.replace(' ', ':') == gt_from_avion
5456
conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name )
5557

5658
data = {'video': vid_path,

shaokai_generate_train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ python3 action/generate_description.py \
88
--out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train \
99
--avion_train_predictions /storage-rcp-pure/upmwmathis_scratch/shaokai/avion_predictions_train.json \
1010
--gen_type avion_mc \
11-
--n_options 10
11+
--n_options 3
1212

0 commit comments

Comments
 (0)