11import json
22import glob
3-
3+ import os
44class PredictionAnalysis :
55 """
66 We save data that can be used for ad-hoc analysis
@@ -19,8 +19,11 @@ class PredictionAnalysis:
1919 vid_path: ''
2020 }
2121 """
def __init__(self, save_folder='.', rank=0):
    """Set up an empty prediction buffer for one rank.

    Args:
        save_folder: Directory in which the per-rank JSON buffer lives.
        rank: Rank of the current process; it is stored on the instance
            and embedded in the buffer file name.
    """
    buffer_prefix = 'prediction_analysis_buf'
    buffer_name = f'{buffer_prefix}_rank{rank}.json'

    self.save_folder = save_folder
    self.rank = rank
    # Shared file-name stem; every rank's buffer starts with it.
    self.prefix = buffer_prefix
    self.save_path = os.path.join(save_folder, buffer_name)
    # In-memory store of the logged records.
    self.data = {}
2528 def log (self ,
2629 global_index ,
@@ -50,52 +53,50 @@ def save(self):
5053 json .dump (self .data , f , indent = 4 )
5154
5255
53- class Analysis :
54- """
55-
56- This same code should be applied to the training too.
57-
58- collect all the wrong top-1 prediction from avion
59- collect all the wrong top-1 prediction from llava
60-
61- Determine percentage of wrong llava prediction that has wrong verb only
62- Determine percentage of wrong llava prediction that has wrong noun only
63- Determine percentage of wrong llava prediciton that has both verb and noun wrong
64- Determine percentage of wrong llava prediction that was wrong because the answer not in the top k
65- """
66- pass
67-
68- def __init__ (self , prefix ):
69-
70- files = glob .glob (prefix + '*' )
71-
72- self .data = {}
73-
74- for file in files :
75- print ('loading pred checkpoint from: ' , file )
76- with open (file , 'r' ) as f :
77- _data = json .load (f )
78- self .data .update (_data )
def load(self):
    """Merge the per-rank prediction buffers into ``self.data``.

    Only rank 0 reads from disk: it looks in ``self.save_folder`` for
    every file whose name starts with ``self.prefix``, parses each one as
    JSON and folds the resulting mapping into ``self.data``.  Other ranks
    leave ``self.data`` untouched.  Afterwards the keys currently held in
    ``self.data`` are printed in ascending numeric order.

    Raises:
        json.JSONDecodeError: If a matching file does not contain valid JSON.
        ValueError: If a key in ``self.data`` cannot be converted by ``int``.
    """
    if self.rank == 0:
        # Escape the folder so characters such as "[" or "*" in the path
        # are matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(self.save_folder), self.prefix + '*')
        # glob yields paths in arbitrary order; sorting keeps the merge
        # (and which file wins on duplicate keys) reproducible.
        for path in sorted(glob.glob(pattern)):
            print('loading pred checkpoint from: ', path)
            with open(path, 'r') as f:
                self.data.update(json.load(f))

    # NOTE(review): printed for every rank (an empty list on non-zero
    # ranks) -- confirm whether this was meant to be rank-0 only.
    print(sorted(self.data, key=int))
8167
def wrong_verb(self):
    """Collect the top-1 mistakes made by LLaVA and by AVION.

    Walks ``self.data`` in ascending numeric key order.  For each record
    the ``llava_pred`` string and the first entry of
    ``avion_preds['predictions']`` are compared against ``gt_name``; the
    AVION string has only its first ``:`` replaced by a space before the
    comparison.

    Returns:
        tuple[list, list]: ``(wrong_llava, wrong_avion)``.  Each list holds
        ``(prediction, ground_truth)`` pairs for the records where that
        model's top-1 prediction differs from the ground truth.

    Raises:
        KeyError: If a record lacks one of the expected fields.
        IndexError: If a record's AVION prediction list is empty.
        ValueError: If a key of ``self.data`` cannot be converted by ``int``.
    """
    # TODO: the verb-only / noun-only / both-wrong breakdown is not
    # implemented yet; only the overall wrong predictions are gathered.
    wrong_llava_collections = []
    wrong_avion_collections = []

    # Keys are compared as integers so that e.g. "10" sorts after "2".
    for index in sorted(self.data, key=int):
        items = self.data[index]
        llava_pred = items['llava_pred']
        gt_name = items['gt_name']
        # Only the first ":" is replaced; later ones stay in the string.
        avion_pred = items['avion_preds']['predictions'][0].replace(':', ' ', 1)

        if llava_pred != gt_name:
            wrong_llava_collections.append((llava_pred, gt_name))
        if avion_pred != gt_name:
            # Stored as (pred, gt).
            wrong_avion_collections.append((avion_pred, gt_name))

    return wrong_llava_collections, wrong_avion_collections
9997
if __name__ == '__main__':
    # Ad-hoc entry point: merge the per-rank buffers found in a fixed
    # folder and print the loaded keys.
    analysis_root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT'
    prediction_analysis = PredictionAnalysis(save_folder=analysis_root)
    prediction_analysis.load()
0 commit comments