@@ -297,7 +297,9 @@ def __getitem__(self, i):
297297
298298 data = self .mc_generator .generate_multi_choice (label , self .topk_predictions )
299299
300- return frames , data , time_meta
300+ dataset_size = len (self .samples )
301+
302+ return frames , data , time_meta , i
301303
302304
303305
@@ -502,27 +504,27 @@ def evaluate_on_EK100(eval_args,
502504
503505 if eval_args .action_predictions :
504506 with open (eval_args .action_predictions , 'r' ) as f :
505- predictions = json .load (f )
506-
507-
507+ predictions = json .load (f )
508508
509- avaion_correct = torch . tensor ( 0 , device = 'cuda' )
510- running_corrects = torch . tensor ( 0 , device = 'cuda' )
511- total_samples = torch . tensor ( 0 , device = 'cuda' )
509+ avaion_correct = 0
510+ running_corrects = 0
511+ total_samples = 0
512512
513- for idx , (frames , mc_data , time_meta ) in tqdm (enumerate (val_dataloader )):
513+ for idx , (frames , mc_data , time_meta , global_index ) in tqdm (enumerate (val_dataloader )):
514+
515+ logger .info (f'Process { dist .get_rank ()} got index { global_index } ' )
514516
515517 gt_name = mc_data ['gt_answer_name' ][0 ][0 ]
516-
518+
517519 if eval_args .action_predictions :
518- mc_data = get_topk_predictions (predictions , idx , eval_args .topk_predictions )
520+ mc_data = get_topk_predictions (predictions , global_index . item () , eval_args .topk_predictions )
519521 avion_pred = mc_data ['avion_pred' ]
520522 if gt_name == avion_pred :
521523 avaion_correct += 1
522524
523525 # we don't want to evaluate the whole thing
524526 # let's evaluate 1000 samples to get the complete picture
525- if finish_early and idx > 999 :
527+ if finish_early and idx > ( 1000 / dist . get_world_size ()) :
526528 break
527529
528530 # Update running corrects and total samples
@@ -550,7 +552,7 @@ def evaluate_on_EK100(eval_args,
550552 if eval_args .action_predictions :
551553 avaion_accuracy = avaion_correct / total_samples
552554
553-
555+ dist . barrier ()
554556 dist .all_reduce (running_corrects , op = dist .ReduceOp .SUM )
555557 dist .all_reduce (total_samples , op = dist .ReduceOp .SUM )
556558 if eval_args .action_predictions :
0 commit comments