@@ -12,7 +12,7 @@ def datetime2sec(str):
1212 hh , mm , ss = str .split (':' )
1313 return int (hh ) * 3600 + int (mm ) * 60 + float (ss )
1414
15- def generate_train_ann (ann_file , verb_ids , noun_ids , gen_type = 'naive' , avion_prediction_path = '' ):
15+ def generate_train_ann (ann_file , verb_ids , noun_ids , gen_type = 'naive' , avion_prediction_path = '' , n_options = 5 ):
1616 assert gen_type in GEN_TYPES
1717 # epic kitchen uses csv
1818 csv_reader = csv .reader (open (ann_file ))
@@ -39,15 +39,15 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_p
3939 elif gen_type == "random_mc" :
4040 # here we use the index
4141 vn_str = f'{ row [10 ]} :{ row [12 ]} '
42- mc_data = mc_generator .generate_multi_choice (vn_str , 5 )
42+ mc_data = mc_generator .generate_multi_choice (vn_str , n_options )
4343 options = mc_data ['option' ][0 ]
4444 gt_answer_letter = mc_data ['gt_answer_letter' ][0 ]
4545 gt_answer_name = mc_data ['gt_answer_name' ][0 ]
4646 conversation = generate_random_mc_conversation (options , gt_answer_letter , gt_answer_name )
4747 elif gen_type == "avion_mc" :
4848 vn_str = f'{ row [10 ]} :{ row [12 ]} '
4949 avion_preds = avion_train_predictions [str (idx )]['predictions' ]
50- mc_data = mc_generator .generate_multi_choice (vn_str , avion_preds , 5 )
50+ mc_data = mc_generator .generate_multi_choice (vn_str , avion_preds , n_options )
5151 options = mc_data ['option' ][0 ]
5252 gt_answer_letter = mc_data ['gt_answer_letter' ][0 ]
5353 gt_answer_name = mc_data ['gt_answer_name' ][0 ]
@@ -86,27 +86,30 @@ def get_args():
8686 parser .add_argument ('--out_folder' , default = '/data/shaokai/EK100_in_LLAVA/' , type = str )
8787 parser .add_argument ('--avion_train_predictions' , default = '/data/shaokai/avion_predictions_train.json' , type = str )
8888 parser .add_argument ('--gen_type' , default = 'avion_mc' , type = str , choices = GEN_TYPES )
89+ parser .add_argument ('--n_options' , default = 5 , type = int )
8990 return parser .parse_args ()
9091
9192def main ():
9293 args = get_args ()
9394 ann_file = args .train_metadata
94- inst_train_folder = os .path .join (args .out_folder , args .gen_type )
95+ inst_train_folder = os .path .join (args .out_folder , f' { args .gen_type } _top { args . n_options } ' )
9596
9697 print ('train_metadata' , args .train_metadata )
9798 print ('out_folder' , args .out_folder )
9899 print ('loading predictions from ' , args .avion_train_predictions )
99100 print ('gen_type is ' , args .gen_type )
101+ print ('n_options' , args .n_options )
100102
101- os .makedirs (inst_train_folder , exist_ok = True )
103+ os .makedirs (inst_train_folder , exist_ok = True )
102104
103105 anno_path = Path (ann_file ).parent
104106 _ , _ , verb_ids , noun_ids = generate_label_map (anno_path )
105107 conv_lst = generate_train_ann (ann_file ,
106108 verb_ids ,
107109 noun_ids ,
108110 gen_type = args .gen_type ,
109- avion_prediction_path = args .avion_train_predictions )
111+ avion_prediction_path = args .avion_train_predictions ,
112+ n_options = args .n_options )
110113
111114 # save it to a jsonl
112115 with open (os .path .join (inst_train_folder ,'train_convs_narration.jsonl' ), 'w' ) as f :
0 commit comments