|
8 | 8 | import torch |
9 | 9 | import argparse |
10 | 10 | import decord |
11 | | -from torch.utils.data import DataLoader |
| 11 | +from torch.utils.data import DataLoader, DistributedSampler |
12 | 12 | from tqdm import tqdm |
13 | 13 | from pathlib import Path |
14 | 14 | import sys |
|
21 | 21 | from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions |
22 | 22 | import copy |
23 | 23 | from collections import Counter |
| 24 | +import torch.distributed as dist |
| 25 | + |
| 26 | +dist.init_process_group(backend='nccl') |
24 | 27 |
|
25 | 28 | def datetime2sec(str): |
26 | 29 | hh, mm, ss = str.split(':') |
@@ -138,7 +141,8 @@ def video_loader(root, vid, ext, second, end_second, |
138 | 141 | time_meta['n_frames'] = res.shape[0] |
139 | 142 | all_frame_ids = np.concatenate(all_frame_ids, axis = 0) |
140 | 143 | frame_time = [e/fps for e in all_frame_ids] |
141 | | - frame_time = [f"{i:.2f}s" for i in frame_time] |
| 144 | + frame_time-= frame_time[0] |
| 145 | + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) |
142 | 146 | time_meta['frame_time'] = frame_time |
143 | 147 | assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids) |
144 | 148 | return res, time_meta |
@@ -342,7 +346,7 @@ def prepare_llava(pretrained): |
342 | 346 | if 'video' in pretrained: |
343 | 347 | overwrite_config = {'tie_word_embeddings': False, 'use_cache': True, "vocab_size": 152064} |
344 | 348 |
|
345 | | - print ('overwrite config', overwrite_config) |
| 349 | + |
346 | 350 | tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, |
347 | 351 | None, |
348 | 352 | model_name, |
@@ -404,9 +408,7 @@ def ensemble_llava_evaluation( |
404 | 408 | random.shuffle(options) |
405 | 409 | for idx, (option, letter) in enumerate(zip(options, letters)): |
406 | 410 | sep = option.index('.') |
407 | | - options[idx] = f'{letter}.{option[sep+1:]}' |
408 | | - rank0_print ('generated new option sequence') |
409 | | - rank0_print (options) |
| 411 | + options[idx] = f'{letter}.{option[sep+1:]}' |
410 | 412 |
|
411 | 413 | pred = llava_inference( |
412 | 414 | pretrained_name, |
@@ -466,10 +468,12 @@ def evaluate_on_EK100(eval_args, |
466 | 468 |
|
467 | 469 | ) |
468 | 470 |
|
469 | | - val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False) |
| 471 | + if dist.is_initialized(): |
| 472 | + sampler = DistributedSampler(val_dataset, shuffle=False) |
| 473 | + else: |
| 474 | + sampler = None |
470 | 475 |
|
471 | | - running_corrects = 0 |
472 | | - total_samples = 0 |
| 476 | + val_dataloader = DataLoader(val_dataset, sampler = sampler, batch_size=1, shuffle=False) |
473 | 477 |
|
474 | 478 | # Set up logging |
475 | 479 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w') |
@@ -500,8 +504,12 @@ def evaluate_on_EK100(eval_args, |
500 | 504 | with open(eval_args.action_predictions, 'r') as f: |
501 | 505 | predictions = json.load(f) |
502 | 506 |
|
503 | | - avaion_correct = 0 |
504 | | - |
| 507 | + |
| 508 | + |
| 509 | + avaion_correct = torch.tensor(0, device='cuda') |
| 510 | + running_corrects = torch.tensor(0, device='cuda') |
| 511 | + total_samples = torch.tensor(0, device='cuda') |
| 512 | + |
505 | 513 | for idx, (frames, mc_data, time_meta) in tqdm(enumerate(val_dataloader)): |
506 | 514 |
|
507 | 515 | gt_name = mc_data['gt_answer_name'][0][0] |
@@ -538,14 +546,28 @@ def evaluate_on_EK100(eval_args, |
538 | 546 | # Calculate and log running mean accuracy |
539 | 547 | running_accuracy = running_corrects / total_samples |
540 | 548 |
|
541 | | - logger.info(f'running accuracy: {running_accuracy:.4f}') |
| 549 | + logger.info(f'Process {dist.get_rank()} - Running accuracy: {running_accuracy:.4f}') |
542 | 550 | if eval_args.action_predictions: |
543 | 551 | avaion_accuracy = avaion_correct / total_samples |
544 | 552 |
|
545 | 553 |
|
546 | | - logger.info(f'Running avaion accuracy after {total_samples} samples: {avaion_accuracy:.4f}') |
547 | | - logger.info(f'Final accuracy: {running_accuracy:.4f}') |
548 | | - return running_accuracy |
| 554 | + dist.all_reduce(running_corrects, op=dist.ReduceOp.SUM) |
| 555 | + dist.all_reduce(total_samples, op=dist.ReduceOp.SUM) |
| 556 | + if eval_args.action_predictions: |
| 557 | + dist.all_reduce(avaion_correct, op=dist.ReduceOp.SUM) |
| 558 | + |
| 559 | + # Calculate global accuracy after reduction |
| 560 | + global_accuracy = running_corrects.item() / total_samples.item() |
| 561 | + if eval_args.action_predictions: |
| 562 | + global_avaion_accuracy = avaion_correct.item() / total_samples.item() |
| 563 | + |
| 564 | + # Ensure only the main process (rank 0) prints the final result |
| 565 | + if dist.get_rank() == 0: |
| 566 | + if eval_args.action_predictions: |
| 567 | + logger.info(f'Global Avaion Accuracy: {global_avaion_accuracy:.4f}') |
| 568 | + logger.info(f'Final Global Accuracy: {global_accuracy:.4f}') |
| 569 | + |
| 570 | + return global_accuracy |
549 | 571 |
|
550 | 572 |
|
551 | 573 | if __name__ == '__main__': |
|
0 commit comments