Skip to content

Commit 2a02a4f

Browse files
author
Ye Shaokai
committed
fixed double init
1 parent 3e4bba9 commit 2a02a4f

File tree

2 files changed

+52
-53
lines changed

2 files changed

+52
-53
lines changed

action/ek_eval.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from collections import Counter
2525
import torch.distributed as dist
2626

27-
dist.init_process_group(backend='nccl')
27+
if not dist.is_initialized():
28+
dist.init_process_group(backend='nccl')
2829
rank = dist.get_rank()
2930
torch.cuda.set_device(rank)
3031

@@ -282,7 +283,7 @@ def __init__(
282283
self.ann_root = Path(metadata).parent
283284
self.mc_generator = MultiChoiceGenerator(self.ann_root)
284285
self.rank = dist.get_rank()
285-
self.prediction_analysis = PredictionAnalysis(f'prediction_analysis_buf_rank{self.rank}.json')
286+
self.prediction_analysis = PredictionAnalysis(rank = self.rank)
286287

287288
def __getitem__(self, i):
288289
frames, label, time_meta = self.get_raw_item(
@@ -340,7 +341,7 @@ def get_args_parser():
340341
parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
341342
parser.add_argument('--topk_predictions', default = 5, type =int)
342343
parser.add_argument('--llava_checkpoint', default = None, type = str)
343-
parser.add_argument('--early_stop', default = None, type = int)
344+
344345

345346
return parser
346347

@@ -542,10 +543,7 @@ def evaluate_on_EK100(eval_args,
542543
# let's evaluate 1000 samples to get the complete picture
543544
if finish_early and idx> (1000 / dist.get_world_size()):
544545
break
545-
546-
if eval_args.early_stop and idx > eval_args.early_stop:
547-
break
548-
546+
549547
# Update running corrects and total samples
550548

551549
llava_correct, llava_pred = ensemble_llava_evaluation(
@@ -565,14 +563,14 @@ def evaluate_on_EK100(eval_args,
565563

566564
# log the predictions into prediciton analysis
567565

568-
# val_dataset.prediction_analysis.log(global_index,
569-
# llava_pred,
570-
# gt_name,
571-
# predictions[str(global_index)],
572-
# time_meta['start_second'].item(),
573-
# time_meta['end_second'].item(),
574-
# time_meta['vid_path'],
575-
# dataset_name = 'EK100')
566+
val_dataset.prediction_analysis.log(global_index,
567+
llava_pred,
568+
gt_name,
569+
predictions[str(global_index)],
570+
time_meta['start_second'].item(),
571+
time_meta['end_second'].item(),
572+
time_meta['vid_path'],
573+
dataset_name = 'EK100')
576574

577575

578576

@@ -623,7 +621,7 @@ def evaluate_on_EK100(eval_args,
623621
logger.info(f'Global Avion Accuracy: {global_avion_accuracy:.4f}')
624622
logger.info(f'Final Global Accuracy: {global_accuracy:.4f}')
625623

626-
#val_dataset.prediction_analysis.save()
624+
val_dataset.prediction_analysis.save()
627625

628626
return global_accuracy
629627

action/prediction_analysis.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import glob
3-
3+
import os
44
class PredictionAnalysis:
55
"""
66
We save data that can be used for ad-hoc analysis
@@ -19,8 +19,11 @@ class PredictionAnalysis:
1919
vid_path: ''
2020
}
2121
"""
22-
def __init__(self, save_path):
23-
self.save_path = save_path
22+
def __init__(self, save_folder = '.', rank = 0):
23+
self.save_folder = save_folder
24+
self.rank = rank
25+
self.prefix = 'prediction_analysis_buf'
26+
self.save_path = os.path.join(save_folder, f'{self.prefix}_rank{rank}.json')
2427
self.data = {}
2528
def log(self,
2629
global_index,
@@ -50,52 +53,50 @@ def save(self):
5053
json.dump(self.data, f, indent = 4)
5154

5255

53-
class Analysis:
54-
"""
55-
56-
This same code should be applied to the training too.
57-
58-
collect all the wrong top-1 prediction from avion
59-
collect all the wrong top-1 prediction from llava
60-
61-
Determine percentage of wrong llava prediction that has wrong verb only
62-
Determine percentage of wrong llava prediction that has wrong noun only
63-
Determine percentage of wrong llava prediciton that has both verb and noun wrong
64-
Determine percentage of wrong llava prediction that was wrong because the answer not in the top k
65-
"""
66-
pass
67-
68-
def __init__(self, prefix):
69-
70-
files = glob.glob(prefix + '*')
71-
72-
self.data = {}
73-
74-
for file in files:
75-
print ('loading pred checkpoint from: ', file)
76-
with open(file, 'r') as f:
77-
_data = json.load(f)
78-
self.data.update(_data)
56+
def load(self):
57+
save_folder = self.save_folder
58+
if self.rank == 0:
59+
files = glob.glob(os.path.join(save_folder,self.prefix + '*'))
60+
for file in files:
61+
print ('loading pred checkpoint from: ', file)
62+
with open(file, 'r') as f:
63+
_data = json.load(f)
64+
self.data.update(_data)
7965

80-
# add some assertion for number of keys in the data
66+
print (sorted(list(self.data.keys()), key = lambda x: int(x)))
8167

8268
def wrong_verb(self):
8369

8470
N = len(self.data)
71+
llava_wrong_verb_collections = []
72+
llava_wrong_noun_collections = []
73+
llava_wrong_verb_noun_collections = []
8574

86-
wrong_verb_collections = []
87-
wrong_noun_collections = []
88-
wrong_verb_noun_collections = []
75+
avion_wrong_verb_collections = []
76+
avion_wrong_noun_collections = []
77+
avion_wrong_verb_noun_collections = []
8978

9079
wrong_llava_collections = []
9180
wrong_avion_collections = []
9281

93-
indices = sorted(self.data.keys())
82+
indices = sorted(list(self.data.keys()), key = lambda x: int(x))
9483

9584
for index in indices:
9685
items = self.data[index]
97-
98-
86+
llava_pred = items['llava_pred']
87+
gt_name = items['gt_name']
88+
# only replacing the first :
89+
avion_pred = items['avion_preds']['predictions'][0].replace(':', ' ', 1)
90+
91+
if llava_pred != gt_name:
92+
wrong_llava_collections.append((llava_pred, gt_name))
93+
if avion_pred!= gt_name:
94+
# pred, gt
95+
wrong_avion_collections.append((avion_pred, gt_name))
96+
9997

10098
if __name__ == '__main__':
101-
pass
99+
100+
101+
prediction_analysis = PredictionAnalysis(save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT')
102+
prediction_analysis.load()

0 commit comments

Comments
 (0)