@@ -99,7 +99,7 @@ def get_uid_options_map_from_prediction_folder(uid, prediction_folder):
9999 options = value ["avion_preds" ]
100100 ret [uid ] = options
101101 return ret
102- def get_uid_official_map (uid , ann_file ):
102+ def get_uid_official_map (ann_file ):
103103 csv_reader = csv .reader (open (ann_file , 'r' ))
104104 _ = next (csv_reader )
105105 anno_root = Path (ann_file ).parent
@@ -124,6 +124,32 @@ def get_uid_official_map(uid, ann_file):
124124 ret [uid ] = official_key
125125 return ret
126126
def get_uid_narration_map(ann_file):
    """Map every annotated segment's uid to its narration text.

    The annotation CSV is read row by row after skipping the header line.
    For each row the uid is built from columns 1 and 2 (``pid``, ``vid``)
    and from the start/end timestamps in columns 4 and 5, which are
    converted with ``datetime2sec`` and rounded to two decimals.

    Args:
        ann_file: Path to the annotation CSV file.

    Returns:
        dict mapping ``'{pid}-{vid}_{start_second}_{end_second}'`` to the
        narration string stored in column 8 of the same row. An empty file
        yields an empty dict.
    """
    ret = {}
    # The csv module documentation asks for newline='' on files handed to
    # csv.reader; the context manager closes the handle when done.
    with open(ann_file, 'r', newline='') as f:
        csv_reader = csv.reader(f)
        next(csv_reader, None)  # skip the header row (tolerates an empty file)
        for row in csv_reader:
            pid, vid = row[1:3]
            start_second = round(float(datetime2sec(row[4])), 2)
            end_second = round(float(datetime2sec(row[5])), 2)
            # Only the narration column is needed here, so no verb/noun
            # label maps are loaded.
            uid = f'{pid}-{vid}_{start_second}_{end_second}'
            ret[uid] = row[8]
    return ret
151+
152+
127153
128154def get_narration_by_uid (uid , ann_file ):
129155 csv_reader = csv .reader (open (ann_file , 'r' ))
@@ -177,7 +203,17 @@ def get_uid_official_map(uid, ann_file):
177203 official_key = verb_maps [str (verb )] + ' ' + noun_maps [str (noun )]
178204 ret [uid ] = official_key
179205 return ret
206+
207+
208+
209+
def compare_caption_generation(chatgpt_file, llava_file):
    """Placeholder for a caption-generation comparison; not implemented yet.

    Both arguments are currently ignored and ``None`` is returned.
    Open question left by the author: is there a llava file for this yet?
    """
    return None
180213
def compare_open_ended_question_answering(chatgpt_file, llava_file):
    """Placeholder for an open-ended QA comparison; not implemented yet.

    Both arguments are currently ignored and ``None`` is returned.
    Open question left by the author: is there a llava file for this yet?
    """
    return None
181217
182218def search_llavaction_win (tim_chatgpt_file ,
183219 random_chatgpt_file ,
@@ -202,18 +238,7 @@ def search_llavaction_win(tim_chatgpt_file,
202238 if llavaction_pred [uid ]['pred' ] == llavaction_pred [uid ]['gt' ] and \
203239 tim_chatgpt_pred [uid ]['pred' ] != tim_chatgpt_pred [uid ]['gt' ] and \
204240 llava_pred [uid ]['pred' ] != llava_pred [uid ]['gt' ]:
205-
206- # print ('uid', uid)
207- # print ('gt', tim_chatgpt_pred[uid]['gt'])
208- # print ('tim_chatgpt_pred', tim_chatgpt_pred[uid]['pred'])
209- # print ('llava_pred', llava_pred[uid]['pred'])
210- # print ('llavaction_pred', llavaction_pred[uid]['pred'])
211- # print ('options', tim_chatgpt_options)
212- # print ('llava_options', llava_options)
213- # print ('llavaction_options', llavaction_options)
214- # print ('random_chatgpt_options', random_chatgpt_options)
215- # print ('----')
216- # get all these printed items in the results dictionary
241+
217242 results [uid ] = {'gt' : tim_chatgpt_pred [uid ]['gt' ],
218243 'tim_chatgpt_pred' : tim_chatgpt_pred [uid ]['pred' ],
219244 'llava_pred' : llava_pred [uid ]['pred' ],
@@ -226,33 +251,50 @@ def search_llavaction_win(tim_chatgpt_file,
226251 with open ('llavaction_win.json' , 'w' ) as f :
227252 json .dump (results , f , indent = 4 )
228253
def get_wrong_prediction_uids(prediction_folder, ann_file,
                              output_file='llava_gets_confused_by_key.json'):
    """Collect the samples where llava's prediction differs from the ground truth.

    All ``*.json`` files in ``prediction_folder`` are merged into one dict and
    visited in ascending order of their integer keys. A sample is reported
    only when its ``gt_name`` is among its ``avion_preds`` options and its
    ``llava_pred`` differs from ``gt_name``. Each reported sample is printed
    and stored, and the collection is written to ``output_file`` as JSON.

    Args:
        prediction_folder: Directory holding the prediction JSON files.
        ann_file: Annotation CSV used to look up the narration of each uid.
        output_file: Destination of the JSON report.

    Returns:
        dict mapping uid to the reported fields (``llava_pred``, ``gt``,
        ``narration``, ``official_key``, ``options``).

    Raises:
        KeyError: If a reported sample's uid is absent from the annotation
            file, or a prediction entry lacks an expected field.
    """
    uid_narration_map = get_uid_narration_map(ann_file)

    data = {}
    # glob returns files in arbitrary order; sorting makes it deterministic
    # which file wins should two files share a key.
    for file in sorted(glob.glob(os.path.join(prediction_folder, '*.json'))):
        with open(file, 'r') as f:
            data.update(json.load(f))

    results = {}
    for key in sorted(data, key=int):
        value = data[key]
        gt_name = value['gt_name']
        llava_pred = value['llava_pred']
        options = value['avion_preds']

        # Skip samples whose ground truth is not even among the options.
        if gt_name not in options:
            continue
        # Only mistakes are of interest.
        if llava_pred == gt_name:
            continue

        path_parts = value['vid_path'].split('/')
        left, right = path_parts[0], path_parts[1]
        start_second = value['start_second']
        end_second = value['end_second']
        uid = f'{left}-{right}_{start_second}_{end_second}'
        narration = uid_narration_map[uid]

        print('uid', uid)
        print('llava_pred', llava_pred)
        print('official key gt', gt_name)
        print('gt narration', narration)
        print('options', options)

        # 'gt' and 'official_key' both carry gt_name; both keys are kept so
        # the report layout stays the same.
        results[uid] = {'llava_pred': llava_pred,
                        'gt': gt_name,
                        'narration': narration,
                        'official_key': gt_name,
                        'options': options}

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=4)
    return results
256298
257299def walk_through (ann_file ):
258300 csv_reader = csv .reader (open (ann_file , 'r' ))
@@ -283,20 +325,20 @@ def walk_through(ann_file):
283325 print ('narration' , narration )
284326 print ('----' )
285327 count += 1
286- print ( 'count' , count )
328+
287329
if __name__ == '__main__':
    # Script entry point with hard-coded locations of the annotation CSV
    # and the prediction files; edit these paths to run elsewhere.
    ann_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
    prediction_folder = '/data/shaokai/predictions_for_vis/dev_7b_16f_top5_full_includes_tim/'
    #walk_through(ann_file)
    get_wrong_prediction_uids(prediction_folder, ann_file)
    root = '/data/shaokai/predictions_for_vis/'
    chatgpt_tim_file = os.path.join(root, 'gpt-4o-2024-08-06_tim_GT_random_narration_top5_8f_9668samples.json')
    chatgpt_random_file = os.path.join(root, 'gpt-4o-2024-08-06_random_GT_random_narration_top5_8f_9668samples.json')
    llava_zeroshot_folder = os.path.join(root, 'LLaVA_Video_7B')
    llavaction_folder = os.path.join(root, 'LLaVAction_7B')
    # The inputs above are only consumed by the disabled call below.
    # search_llavaction_win(chatgpt_tim_file,
    #                       chatgpt_random_file,
    #                       llava_zeroshot_folder,
    #                       llavaction_folder)
302344
0 commit comments