|
24 | 24 | from collections import Counter |
25 | 25 | import torch.distributed as dist |
26 | 26 |
|
27 | | -if not dist.is_initialized(): |
28 | | - dist.init_process_group(backend='nccl') |
29 | | - rank = dist.get_rank() |
30 | | - torch.cuda.set_device(rank) |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | +def setup(rank, world_size): |
| 31 | + # Check if the process group is already initialized |
| 32 | + if not dist.is_initialized(): |
| 33 | + # Initialize the process group if it hasn't been initialized yet |
| 34 | + os.environ['MASTER_ADDR'] = '127.0.0.1' # Replace with master node IP |
| 35 | + os.environ['MASTER_PORT'] = '29500' # Set a port for communication |
| 36 | + |
| 37 | + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) |
| 38 | + print(f"Process group initialized for rank {rank}") |
| 39 | + |
| 40 | + # Set the GPU device based on rank |
| 41 | + local_rank = rank % torch.cuda.device_count() |
| 42 | + torch.cuda.set_device(local_rank) |
| 43 | + print(f"Using GPU {local_rank} for rank {rank}") |
| 44 | + |
31 | 45 |
|
32 | 46 | def datetime2sec(str): |
33 | 47 | hh, mm, ss = str.split(':') |
@@ -318,6 +332,11 @@ def evaluate_on_EK100(eval_args, |
318 | 332 | tokenizer= None, |
319 | 333 | image_processor= None): |
320 | 334 |
|
| 335 | + world_size = int(os.environ['WORLD_SIZE']) |
| 336 | + rank = int(os.environ['RANK']) |
| 337 | + setup(rank, world_size) |
| 338 | + |
| 339 | + |
321 | 340 | if model is not None: |
322 | 341 | image_processor = model.get_vision_tower().image_processor |
323 | 342 |
|
@@ -397,7 +416,7 @@ def evaluate_on_EK100(eval_args, |
397 | 416 | local_total_samples = torch.tensor(0.0, device=device) |
398 | 417 |
|
399 | 418 | if eval_args.action_predictions: |
400 | | - mc_data = create_multi_choice_from_avion_predictions(predictions[global_index], eval_args.topk_predictions) |
| 419 | + mc_data = create_multi_choice_from_avion_predictions(predictions[str(global_index)]['predictions'], eval_args.topk_predictions) |
401 | 420 | avion_pred = mc_data['avion_pred'] |
402 | 421 | if gt_name == avion_pred: |
403 | 422 | local_avion_correct.add_(1) |
|
0 commit comments