Skip to content

Commit 07747e4

Browse files
committed
Merge branch 'shaokai/dev' of github.com:HaozheQi/LLaVA-NeXT into shaokai/dev
2 parents ce4129e + ccf3ff2 commit 07747e4

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

action/ek_eval.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,24 @@
2424
from collections import Counter
2525
import torch.distributed as dist
2626

27-
if not dist.is_initialized():
28-
dist.init_process_group(backend='nccl')
29-
rank = dist.get_rank()
30-
torch.cuda.set_device(rank)
27+
28+
29+
30+
def setup(rank, world_size):
31+
# Check if the process group is already initialized
32+
if not dist.is_initialized():
33+
# Initialize the process group if it hasn't been initialized yet
34+
os.environ['MASTER_ADDR'] = '127.0.0.1' # Replace with master node IP
35+
os.environ['MASTER_PORT'] = '29500' # Set a port for communication
36+
37+
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
38+
print(f"Process group initialized for rank {rank}")
39+
40+
# Set the GPU device based on rank
41+
local_rank = rank % torch.cuda.device_count()
42+
torch.cuda.set_device(local_rank)
43+
print(f"Using GPU {local_rank} for rank {rank}")
44+
3145

3246
def datetime2sec(str):
3347
hh, mm, ss = str.split(':')
@@ -318,6 +332,11 @@ def evaluate_on_EK100(eval_args,
318332
tokenizer= None,
319333
image_processor= None):
320334

335+
world_size = int(os.environ['WORLD_SIZE'])
336+
rank = int(os.environ['RANK'])
337+
setup(rank, world_size)
338+
339+
321340
if model is not None:
322341
image_processor = model.get_vision_tower().image_processor
323342

@@ -397,7 +416,7 @@ def evaluate_on_EK100(eval_args,
397416
local_total_samples = torch.tensor(0.0, device=device)
398417

399418
if eval_args.action_predictions:
400-
mc_data = create_multi_choice_from_avion_predictions(predictions[global_index], eval_args.topk_predictions)
419+
mc_data = create_multi_choice_from_avion_predictions(predictions[str(global_index)]['predictions'], eval_args.topk_predictions)
401420
avion_pred = mc_data['avion_pred']
402421
if gt_name == avion_pred:
403422
local_avion_correct.add_(1)

0 commit comments

Comments
 (0)