Skip to content

Commit 1934163

Browse files
author
Ye Shaokai
committed
fixed one more issue
1 parent 1db0cf2 commit 1934163

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

action/ek_eval.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import torch.distributed as dist
2525

2626
dist.init_process_group(backend='nccl')
27+
rank = dist.get_rank()
28+
torch.cuda.set_device(rank)
2729

2830
def datetime2sec(str):
2931
hh, mm, ss = str.split(':')
@@ -506,14 +508,11 @@ def evaluate_on_EK100(eval_args,
506508
with open(eval_args.action_predictions, 'r') as f:
507509
predictions = json.load(f)
508510

509-
avaion_correct = 0
510-
running_corrects = 0
511-
total_samples = 0
512-
513-
for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):
514-
515-
logger.info(f'Process {dist.get_rank()} got index {global_index}')
511+
avaion_correct = torch.tensor(0, device='cuda')
512+
running_corrects = torch.tensor(0, device='cuda')
513+
total_samples = torch.tensor(0, device='cuda')
516514

515+
for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):
517516
gt_name = mc_data['gt_answer_name'][0][0]
518517

519518
if eval_args.action_predictions:
@@ -523,7 +522,7 @@ def evaluate_on_EK100(eval_args,
523522
avaion_correct+=1
524523

525524
# we don't want to evaluate the whole thing
526-
# let's evaluate 1000 samples to get the complete picture
525+
# let's evaluate 1000 samples to get the complete picture
527526
if finish_early and idx> (1000 / dist.get_world_size()):
528527
break
529528

0 commit comments

Comments
 (0)