@@ -20,10 +20,10 @@ class PredictionAnalysis:
2020 vid_path: ''
2121 }
2222 """
23- def __init__ (self , save_folder = '.' , rank = 0 ):
23+ def __init__ (self , save_folder = '.' , rank = 0 , prefix = 'prediction_analysis_buf' ):
2424 self .save_folder = save_folder
2525 self .rank = rank
26- self .prefix = 'prediction_analysis_buf'
26+ self .prefix = prefix
2727 self .save_path = os .path .join (save_folder , f'{ self .prefix } _rank{ rank } .json' )
2828 self .data = {}
2929 def log (self ,
@@ -94,12 +94,7 @@ def analysis(self):
9494 avion_verb , avion_noun = avion_pred .split (' ' )
9595 gt_verb , gt_noun = gt_name .split (' ' )
9696
97- if llava_pred != gt_name :
98- if set (llava_pred ).intersection (set (gt_name )) == set (gt_name ):
99- print ('what is going on' )
100- print ('nooo' , llava_pred , gt_name )
101- #wrong_llava_collections.append((llava_pred, gt_name))
102- #print (llava_pred, gt_name)
97+ if llava_pred != gt_name :
10398 wrong_llava_collections [idx ] = 0
10499 else :
105100 wrong_llava_collections [idx ] = 1
@@ -155,6 +150,7 @@ def analysis(self):
155150if __name__ == '__main__' :
156151
157152
158- prediction_analysis = PredictionAnalysis (save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT' )
153+ prediction_analysis = PredictionAnalysis (save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT/llavavideo_avion_mc_top10_5epoch_preds' ,
154+ prefix = 'prediction_analysis_buf' )
159155 prediction_analysis .load ()
160156 prediction_analysis .analysis ()
0 commit comments