22import csv
33import os
44import argparse
5- from action .utils import generate_label_map , MultiChoiceGenerator
5+ from action .utils import generate_label_map , MultiChoiceGenerator , AvionMultiChoiceGenerator
66from pathlib import Path
77
88
@@ -12,7 +12,7 @@ def datetime2sec(str):
1212 hh , mm , ss = str .split (':' )
1313 return int (hh ) * 3600 + int (mm ) * 60 + float (ss )
1414
15- def generate_train_ann (ann_file , verb_ids , noun_ids , gen_type = 'naive' ):
15+ def generate_train_ann (ann_file , verb_ids , noun_ids , gen_type = 'naive' , avion_prediction_path = '' ):
1616 assert gen_type in GEN_TYPES
1717 # epic kitchen uses csv
1818 csv_reader = csv .reader (open (ann_file ))
@@ -21,8 +21,12 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
2121 ann_root = Path (ann_file ).parent
2222 if gen_type == "random_mc" :
2323 mc_generator = MultiChoiceGenerator (ann_root )
24+ elif gen_type == 'avion_mc' :
25+ mc_generator = AvionMultiChoiceGenerator (ann_root )
26+ with open (avion_prediction_path , 'r' ) as f :
27+ avion_train_predictions = json .load (f )
2428
25- for row in csv_reader :
29+ for idx , row in enumerate ( csv_reader ) :
2630 start_timestamp , end_timestamp = datetime2sec (row [4 ]), datetime2sec (row [5 ])
2731
2832 pid , vid = row [1 :3 ]
@@ -40,6 +44,14 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
4044 gt_answer_letter = mc_data ['gt_answer_letter' ][0 ]
4145 gt_answer_name = mc_data ['gt_answer_name' ][0 ]
4246 conversation = generate_random_mc_conversation (options , gt_answer_letter , gt_answer_name )
47+ elif gen_type == "avion_mc" :
48+ vn_str = f'{ row [10 ]} :{ row [12 ]} '
49+ avion_preds = avion_train_predictions [str (idx )]['predictions' ]
50+ mc_data = mc_generator .generate_multi_choice (vn_str , avion_preds , 5 )
51+ options = mc_data ['option' ][0 ]
52+ gt_answer_letter = mc_data ['gt_answer_letter' ][0 ]
53+ gt_answer_name = mc_data ['gt_answer_name' ][0 ]
54+ conversation = generate_random_mc_conversation (options , gt_answer_letter , gt_answer_name )
4355
4456 data = {'video' : vid_path ,
4557 'conversations' : conversation ,
@@ -67,24 +79,38 @@ def generate_random_mc_conversation(options:list[str], gt_answer_letter, gt_answ
6779 {"from" : "gpt" , "value" : f"{ gt_answer_letter } . { gt_answer_name } " }
6880 ]
6981
def generate_avion_mc_conversation():
    """Placeholder for an AVION-specific multi-choice conversation builder.

    Not implemented yet: takes no arguments and returns ``None``.
    The ``avion_mc`` branch of ``generate_train_ann`` currently builds its
    conversation with ``generate_random_mc_conversation`` instead.
    """
    # TODO: implement an AVION-specific prompt, or remove this stub.
    return None
84+
7085
def get_args(argv=None):
    """Parse the command-line options of the EPIC-KITCHENS VQA generator.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default, and what ``main`` uses) argparse reads
            ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace with the attributes ``train_metadata``,
        ``out_folder``, ``avion_train_predictions`` and ``gen_type``.
    """
    parser = argparse.ArgumentParser(description="For generating VQA for EPIC-KITCHEN")
    parser.add_argument(
        '--train_metadata',
        default='/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv',
        type=str,
        help='path to the training annotation CSV file',
    )
    parser.add_argument(
        '--out_folder',
        default='/data/shaokai/EK100_in_LLAVA/',
        type=str,
        help='root output folder; a sub-folder named after gen_type is created inside it',
    )
    parser.add_argument(
        '--avion_train_predictions',
        default='/data/shaokai/avion_predictions_train.json',
        type=str,
        help='path to the JSON file with AVION predictions (read when gen_type is avion_mc)',
    )
    # GEN_TYPES is the module-level list of supported generation modes;
    # argparse rejects any value that is not in it.
    parser.add_argument(
        '--gen_type',
        default='avion_mc',
        type=str,
        choices=GEN_TYPES,
        help='which kind of conversation to generate',
    )
    return parser.parse_args(argv)
7693
77- def main ():
94+ def main ():
7895 args = get_args ()
7996 ann_file = args .train_metadata
80- inst_train_folder = args .out_folder
81- print (ann_file )
82- anno_path = Path (ann_file ).parent
83- labels , mapping_vn2act , verb_ids , noun_ids = generate_label_map (anno_path )
84- conv_lst = generate_train_ann (ann_file , verb_ids , noun_ids , gen_type = 'random_mc' )
85-
97+ inst_train_folder = os .path .join (args .out_folder , args .gen_type )
98+
99+ print ('train_metadata' , args .train_metadata )
100+ print ('out_folder' , args .out_folder )
101+ print ('loading predictions from ' , args .avion_train_predictions )
102+ print ('gen_type is ' , args .gen_type )
103+
86104 os .makedirs (inst_train_folder , exist_ok = True )
87105
106+ anno_path = Path (ann_file ).parent
107+ _ , _ , verb_ids , noun_ids = generate_label_map (anno_path )
108+ conv_lst = generate_train_ann (ann_file ,
109+ verb_ids ,
110+ noun_ids ,
111+ gen_type = args .gen_type ,
112+ avion_prediction_path = args .avion_train_predictions )
113+
88114 # save it to a jsonl
89115 with open (os .path .join (inst_train_folder ,'train_convs_narration.jsonl' ), 'w' ) as f :
90116 for conv in conv_lst :
0 commit comments