11import json
22import glob
33import os
4+ import numpy as np
45class PredictionAnalysis :
56 """
67 We save data that can be used for ad-hoc analysis
@@ -24,7 +25,7 @@ def __init__(self, save_folder = '.', rank = 0):
2425 self .rank = rank
2526 self .prefix = 'prediction_analysis_buf'
2627 self .save_path = os .path .join (save_folder , f'{ self .prefix } _rank{ rank } .json' )
27- self .data = {}
28+ self .data = {}
2829 def log (self ,
2930 global_index ,
3031 llava_pred ,
@@ -62,10 +63,11 @@ def load(self):
6263 with open (file , 'r' ) as f :
6364 _data = json .load (f )
6465 self .data .update (_data )
66+ print ('length' , len (self .data ))
67+ assert len (self .data ) == 9668
68+ #print (sorted(list(self.data.keys()), key = lambda x: int(x)))
6569
66- print (sorted (list (self .data .keys ()), key = lambda x : int (x )))
67-
68- def wrong_verb (self ):
def analysis(self):
    """Report how LLaVA and AVION top-1 predictions compare with the ground truth.

    Reads ``self.data``, a mapping whose keys are integer-like strings and
    whose values are records with the keys ``'llava_pred'``, ``'gt_name'``
    and ``'avion_preds'``.  Only the first entry of
    ``record['avion_preds']['predictions']`` is used, after its first ``:``
    is replaced by a space so it can be compared with the other two names.

    Every action name is treated as ``"<verb> <noun>"``: the verb is the
    text before the first space and the noun is everything after it, so
    multi-word nouns are supported.

    The report is printed to stdout and also returned.

    Returns:
        dict: ``correlation`` (Pearson correlation between the per-sample
        correctness of the two models), ``llava_accuracy`` /
        ``avion_accuracy`` (top-1 action accuracy), and for each model the
        share of its wrong predictions where only the noun, only the verb,
        or both were wrong (``<model>_wrong_noun``, ``<model>_wrong_verb``,
        ``<model>_wrong_verb_noun``).  A statistic that is undefined
        (no samples, no errors, or a constant correctness series) is
        ``nan`` instead of raising or emitting a division warning.
    """

    def split_action(name):
        # First token is the verb; the remainder (possibly several words)
        # is the noun.  partition never raises, unlike a 2-way unpack of
        # split(' ').
        verb, _, noun = name.partition(' ')
        return verb, noun

    def fraction(count, total):
        # An empty denominator means "undefined", not an error.
        return count / total if total else float('nan')

    def bucket_error(pred, gt, buckets):
        # File a (pred, gt) pair under the part of the action that is wrong.
        pred_verb, pred_noun = split_action(pred)
        gt_verb, gt_noun = split_action(gt)
        verb_ok = pred_verb == gt_verb
        noun_ok = pred_noun == gt_noun
        if verb_ok and not noun_ok:
            buckets['noun'].append((pred, gt))
        elif noun_ok and not verb_ok:
            buckets['verb'].append((pred, gt))
        elif not verb_ok and not noun_ok:
            buckets['verb_noun'].append((pred, gt))

    # Keys are strings; order them numerically so both series line up.
    indices = sorted(self.data, key=int)
    total = len(indices)

    # 1 where the prediction equals the ground truth, 0 otherwise.
    llava_correct = []
    avion_correct = []
    llava_errors = {'noun': [], 'verb': [], 'verb_noun': []}
    avion_errors = {'noun': [], 'verb': [], 'verb_noun': []}

    for index in indices:
        items = self.data[index]
        gt_name = items['gt_name']
        llava_pred = items['llava_pred']
        # only replacing the first ':' (the verb/noun separator)
        avion_pred = items['avion_preds']['predictions'][0].replace(':', ' ', 1)

        llava_correct.append(int(llava_pred == gt_name))
        avion_correct.append(int(avion_pred == gt_name))

        bucket_error(llava_pred, gt_name, llava_errors)
        bucket_error(avion_pred, gt_name, avion_errors)

    # Pearson correlation is undefined for fewer than two samples or when
    # either series is constant; skip the call rather than trigger a
    # divide-by-zero warning inside numpy.
    if total > 1 and len(set(llava_correct)) > 1 and len(set(avion_correct)) > 1:
        correlation = float(np.corrcoef(llava_correct, avion_correct)[0, 1])
    else:
        correlation = float('nan')

    llava_wrong = total - sum(llava_correct)
    avion_wrong = total - sum(avion_correct)

    stats = {
        'correlation': correlation,
        'llava_accuracy': fraction(sum(llava_correct), total),
        'avion_accuracy': fraction(sum(avion_correct), total),
        'llava_wrong_noun': fraction(len(llava_errors['noun']), llava_wrong),
        'llava_wrong_verb': fraction(len(llava_errors['verb']), llava_wrong),
        'llava_wrong_verb_noun': fraction(len(llava_errors['verb_noun']), llava_wrong),
        'avion_wrong_noun': fraction(len(avion_errors['noun']), avion_wrong),
        'avion_wrong_verb': fraction(len(avion_errors['verb']), avion_wrong),
        'avion_wrong_verb_noun': fraction(len(avion_errors['verb_noun']), avion_wrong),
    }

    # first, the correlation between avion and llava
    print("Correlation:", stats['correlation'])

    print('llava top1 action accuracy {:.3f}'.format(stats['llava_accuracy']))
    print('avion top1 action accuracy {:.3f}'.format(stats['avion_accuracy']))

    print('llava percentage of wrong noun {:.2f}'.format(stats['llava_wrong_noun']))
    print('llava percentage of wrong verb {:.2f}'.format(stats['llava_wrong_verb']))
    print('llava percentage of both verb noun wrong {:.2f}'.format(stats['llava_wrong_verb_noun']))

    print('avion percentage of wrong noun {:.2f}'.format(stats['avion_wrong_noun']))
    print('avion percentage of wrong verb {:.2f}'.format(stats['avion_wrong_verb']))
    print('avion percentage of both verb noun wrong {:.2f}'.format(stats['avion_wrong_verb_noun']))

    return stats
97154
if __name__ == '__main__':
    # Script entry point: load every per-rank prediction buffer found in the
    # save folder and print the LLaVA-vs-AVION comparison report.
    import sys

    # The folder can be given as the first command-line argument; without
    # one, fall back to the original hard-coded location.
    default_save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT'
    save_folder = sys.argv[1] if len(sys.argv) > 1 else default_save_folder

    prediction_analysis = PredictionAnalysis(save_folder=save_folder)
    prediction_analysis.load()
    prediction_analysis.analysis()
0 commit comments