2424import torch .distributed as dist
2525
2626dist .init_process_group (backend = 'nccl' )
27+ rank = dist .get_rank ()
28+ torch .cuda .set_device (rank )
2729
2830def datetime2sec (str ):
2931 hh , mm , ss = str .split (':' )
@@ -506,14 +508,11 @@ def evaluate_on_EK100(eval_args,
506508 with open (eval_args .action_predictions , 'r' ) as f :
507509 predictions = json .load (f )
508510
509- avaion_correct = 0
510- running_corrects = 0
511- total_samples = 0
512-
513- for idx , (frames , mc_data , time_meta , global_index ) in tqdm (enumerate (val_dataloader )):
514-
515- logger .info (f'Process { dist .get_rank ()} got index { global_index } ' )
511+ avaion_correct = torch .tensor (0 , device = 'cuda' )
512+ running_corrects = torch .tensor (0 , device = 'cuda' )
513+ total_samples = torch .tensor (0 , device = 'cuda' )
516514
515+ for idx , (frames , mc_data , time_meta , global_index ) in tqdm (enumerate (val_dataloader )):
517516 gt_name = mc_data ['gt_answer_name' ][0 ][0 ]
518517
519518 if eval_args .action_predictions :
@@ -523,7 +522,7 @@ def evaluate_on_EK100(eval_args,
523522 avaion_correct += 1
524523
525524 # we don't want to evaluate the whole thing
526- # let's evaluate 1000 samples to get the complete picture
525+ # let's evaluate 1000 samples to get the complete picture
527526 if finish_early and idx > (1000 / dist .get_world_size ()):
528527 break
529528
0 commit comments