Skip to content

Commit 1db0cf2

Browse files
author
Ye Shaokai
committed
fixed distributed evaluation
1 parent ff72c19 commit 1db0cf2

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

action/ek_eval.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ def __getitem__(self, i):
297297

298298
data = self.mc_generator.generate_multi_choice(label, self.topk_predictions)
299299

300-
return frames, data, time_meta
300+
dataset_size = len(self.samples)
301+
302+
return frames, data, time_meta, i
301303

302304

303305

@@ -502,27 +504,27 @@ def evaluate_on_EK100(eval_args,
502504

503505
if eval_args.action_predictions:
504506
with open(eval_args.action_predictions, 'r') as f:
505-
predictions = json.load(f)
506-
507-
507+
predictions = json.load(f)
508508

509-
avaion_correct = torch.tensor(0, device='cuda')
510-
running_corrects = torch.tensor(0, device='cuda')
511-
total_samples = torch.tensor(0, device='cuda')
509+
avaion_correct = 0
510+
running_corrects = 0
511+
total_samples = 0
512512

513-
for idx, (frames, mc_data, time_meta) in tqdm(enumerate(val_dataloader)):
513+
for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):
514+
515+
logger.info(f'Process {dist.get_rank()} got index {global_index}')
514516

515517
gt_name = mc_data['gt_answer_name'][0][0]
516-
518+
517519
if eval_args.action_predictions:
518-
mc_data = get_topk_predictions(predictions, idx, eval_args.topk_predictions)
520+
mc_data = get_topk_predictions(predictions, global_index.item(), eval_args.topk_predictions)
519521
avion_pred = mc_data['avion_pred']
520522
if gt_name == avion_pred:
521523
avaion_correct+=1
522524

523525
# we don't want to evaluate the whole thing
524526
# let's evaluate 1000 samples to get the complete picture
525-
if finish_early and idx>999:
527+
if finish_early and idx> (1000 / dist.get_world_size()):
526528
break
527529

528530
# Update running corrects and total samples
@@ -550,7 +552,7 @@ def evaluate_on_EK100(eval_args,
550552
if eval_args.action_predictions:
551553
avaion_accuracy = avaion_correct / total_samples
552554

553-
555+
dist.barrier()
554556
dist.all_reduce(running_corrects, op=dist.ReduceOp.SUM)
555557
dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
556558
if eval_args.action_predictions:

0 commit comments

Comments
 (0)