1919import logging
2020from llava .utils import rank0_print
2121from action .utils import generate_label_map , MultiChoiceGenerator , match_answer , parse_avion_predictions
22+ from action .prediction_analysis import PredictionAnalysis
2223import copy
2324from collections import Counter
2425import torch .distributed as dist
@@ -224,6 +225,10 @@ def get_raw_item(
224225 fast_rcc = fast_rcc ,
225226 rcc_params = rcc_params ,
226227 jitter = is_training )
228+ time_meta ['start_second' ] = start_second
229+ time_meta ['end_second' ] = end_second
230+ time_meta ['fps' ] = fps
231+ time_meta ['vid_path' ] = vid_path
227232 return frames , '{}:{}' .format (verb , noun ), time_meta
228233 else :
229234 raise NotImplementedError
@@ -271,10 +276,13 @@ def __init__(
271276 self .verb_maps = verb_maps
272277 self .noun_maps = noun_maps
273278 self .vn_list = list (self .label_mapping .keys ())
279+
274280 self .labels = labels
275281 self .topk_predictions = topk_predictions
276282 self .ann_root = Path (metadata ).parent
277283 self .mc_generator = MultiChoiceGenerator (self .ann_root )
284+ self .rank = dist .get_rank ()
285+ self .prediction_analysis = PredictionAnalysis (f'prediction_analysis_buf_rank{ self .rank } .json' )
278286
279287 def __getitem__ (self , i ):
280288 frames , label , time_meta = self .get_raw_item (
@@ -299,8 +307,6 @@ def __getitem__(self, i):
299307
300308 data = self .mc_generator .generate_multi_choice (label , self .topk_predictions )
301309
302- dataset_size = len (self .samples )
303-
304310 return frames , data , time_meta , i
305311
306312
@@ -330,10 +336,11 @@ def get_args_parser():
330336 # llm size is a string and must be a value such as '7b' or '5b'.
331337 parser .add_argument ('--pretrained_name' , default = '' , type = str , help = 'the name in huggingface' )
332338 parser .add_argument ('--llava_num_frames' , default = 16 , type = int , help = 'number of frames for llava' )
333- ## avaion refinement
339+ ## avion refinement
334340 parser .add_argument ('--action_predictions' , default = None , type = str , help = 'path to action predictions' )
335341 parser .add_argument ('--topk_predictions' , default = 5 , type = int )
336342 parser .add_argument ('--llava_checkpoint' , default = None , type = str )
343+ parser .add_argument ('--early_stop' , default = None , type = int )
337344
338345 return parser
339346
@@ -438,7 +445,7 @@ def ensemble_llava_evaluation(
438445 rank0_print ('inspecting the counter' , counter )
439446 rank0_print ('most common' , counter .most_common (1 )[0 ][0 ])
440447
441- return match_answer (counter .most_common (1 )[0 ][0 ], gt_name )
448+ return match_answer (counter .most_common (1 )[0 ][0 ], gt_name ), counter . most_common ( 1 )[ 0 ][ 0 ]
442449
443450
444451
@@ -497,7 +504,7 @@ def evaluate_on_EK100(eval_args,
497504 print ('pretrained' , pretrained )
498505
499506 # so we know it's evaluation during training
500- finish_early = model is not None
507+ finish_early = False # model is not None
501508
502509 if model is None :
503510 if args .llava_checkpoint is not None :
@@ -508,26 +515,40 @@ def evaluate_on_EK100(eval_args,
508515 with open (eval_args .action_predictions , 'r' ) as f :
509516 predictions = json .load (f )
510517
511- avaion_correct = torch .tensor (0 , device = 'cuda' )
512- running_corrects = torch .tensor (0 , device = 'cuda' )
513- total_samples = torch .tensor (0 , device = 'cuda' )
518+ device = torch .device (f'cuda:{ rank } ' )
519+
520+ global_avion_correct = torch .tensor (0.0 , device = device )
521+ global_running_corrects = torch .tensor (0.0 , device = device )
522+ global_total_samples = torch .tensor (0.0 , device = device )
523+
514524
515525 for idx , (frames , mc_data , time_meta , global_index ) in tqdm (enumerate (val_dataloader )):
526+
527+ global_index = global_index .item ()
528+
516529 gt_name = mc_data ['gt_answer_name' ][0 ][0 ]
530+ local_avion_correct = torch .tensor (0.0 , device = device )
531+ local_running_corrects = torch .tensor (0.0 , device = device )
532+ local_total_samples = torch .tensor (0.0 , device = device )
517533
518534 if eval_args .action_predictions :
519- mc_data = get_topk_predictions (predictions , global_index . item () , eval_args .topk_predictions )
535+ mc_data = get_topk_predictions (predictions , global_index , eval_args .topk_predictions )
520536 avion_pred = mc_data ['avion_pred' ]
521537 if gt_name == avion_pred :
522- avaion_correct += 1
538+ local_avion_correct .add_ (1 )
539+ global_avion_correct .add_ (1 )
523540
524541 # we don't want to evaluate the whole thing
525542 # let's evaluate 1000 samples to get the complete picture
526543 if finish_early and idx > (1000 / dist .get_world_size ()):
527544 break
528545
546+ if eval_args .early_stop and idx > eval_args .early_stop :
547+ break
548+
529549 # Update running corrects and total samples
530- running_corrects += ensemble_llava_evaluation (
550+
551+ llava_correct , llava_pred = ensemble_llava_evaluation (
531552 eval_args .pretrained_name ,
532553 gt_name ,
533554 frames ,
@@ -541,33 +562,69 @@ def evaluate_on_EK100(eval_args,
541562 ensemble_k = 1 ,
542563 time_meta = time_meta ,
543564 is_test = not finish_early )
565+
566+ # log the predictions into prediction analysis
567+
568+ # val_dataset.prediction_analysis.log(global_index,
569+ # llava_pred,
570+ # gt_name,
571+ # predictions[str(global_index)],
572+ # time_meta['start_second'].item(),
573+ # time_meta['end_second'].item(),
574+ # time_meta['vid_path'],
575+ # dataset_name = 'EK100')
576+
577+
578+
579+
580+ local_running_corrects .add_ (llava_correct )
581+ global_running_corrects .add_ (llava_correct )
544582
545- total_samples += 1
583+ local_total_samples .add_ (1 )
584+ global_total_samples .add_ (1 )
585+
586+ logger .info (f'Process { dist .get_rank ()} - local_total_samples: { local_total_samples :.4f} ' )
587+
588+ logger .info (f'Process { dist .get_rank ()} - loca_llava_correct: { llava_correct :.4f} ' )
589+
590+ logger .info (f'Process { dist .get_rank ()} - local_running_corrects: { local_running_corrects :.4f} ' )
591+
546592
547593 # Calculate and log running mean accuracy
548- running_accuracy = running_corrects / total_samples
594+ # dist.barrier()
595+ # dist.all_reduce(local_running_corrects, op=dist.ReduceOp.SUM)
596+ # dist.all_reduce(local_total_samples, op=dist.ReduceOp.SUM)
597+ # if eval_args.action_predictions:
598+ # dist.all_reduce(local_avion_correct, op=dist.ReduceOp.SUM)
599+ # dist.barrier()
600+ # # Calculate global accuracy after reduction
601+ # local_running_accuracy = local_running_corrects.item() / local_total_samples.item()
602+ # local_avion_accuracy = local_avion_correct.item() / local_total_samples.item()
603+
604+ # logger.info(f'Process {dist.get_rank()} - Running accuracy: {local_running_accuracy:.4f}')
605+ # logger.info(f'Process {dist.get_rank()} - AvionRunning accuracy: {local_avion_accuracy:.4f}')
549606
550- logger .info (f'Process { dist .get_rank ()} - Running accuracy: { running_accuracy :.4f} ' )
551- if eval_args .action_predictions :
552- avaion_accuracy = avaion_correct / total_samples
607+
553608
554609 dist .barrier ()
555- dist .all_reduce (running_corrects , op = dist .ReduceOp .SUM )
556- dist .all_reduce (total_samples , op = dist .ReduceOp .SUM )
610+ dist .all_reduce (global_running_corrects , op = dist .ReduceOp .SUM )
611+ dist .all_reduce (global_total_samples , op = dist .ReduceOp .SUM )
557612 if eval_args .action_predictions :
558- dist .all_reduce (avaion_correct , op = dist .ReduceOp .SUM )
613+ dist .all_reduce (global_avion_correct , op = dist .ReduceOp .SUM )
559614
560615 # Calculate global accuracy after reduction
561- global_accuracy = running_corrects .item () / total_samples .item ()
616+ global_accuracy = global_running_corrects .item () / global_total_samples .item ()
562617 if eval_args .action_predictions :
563- global_avaion_accuracy = avaion_correct .item () / total_samples .item ()
618+ global_avion_accuracy = global_avion_correct .item () / global_total_samples .item ()
564619
565620 # Ensure only the main process (rank 0) prints the final result
566621 if dist .get_rank () == 0 :
567622 if eval_args .action_predictions :
568- logger .info (f'Global Avaion Accuracy: { global_avaion_accuracy :.4f} ' )
623+ logger .info (f'Global Avion Accuracy: { global_avion_accuracy :.4f} ' )
569624 logger .info (f'Final Global Accuracy: { global_accuracy :.4f} ' )
570625
626+ #val_dataset.prediction_analysis.save()
627+
571628 return global_accuracy
572629
573630
0 commit comments