Commit 93d5c43

Author: Ye Shaokai (committed)

fixed multi node eval

1 parent: 92053a1

File tree: 2 files changed (+8, -8 lines)

  .vscode/launch.json
  llava/action/ek_eval.py


.vscode/launch.json

Lines changed: 1 addition, 1 deletion

@@ -297,7 +297,7 @@
     "--action_representation", "official_key",
     "--topk_predictions", "10",
     "--eval_steps", "1",
-    "--vision_supervision", "all_newlines",
+    "--vision_supervision", "three_tokens",
     "--action_types", "97,300,3806",
     "--n_narration", "5"
 ],

llava/action/ek_eval.py

Lines changed: 7 additions, 7 deletions

@@ -30,19 +30,19 @@ def process_raw_pred(raw_pred):
     return raw_pred

 def setup(rank, world_size):
-    # Check if the process group is already initialized
     if not dist.is_initialized():
-        # Initialize the process group if it hasn't been initialized yet
-        os.environ['MASTER_ADDR'] = '127.0.0.1' # Replace with master node IP
-        os.environ['MASTER_PORT'] = '29500' # Set a port for communication
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'

         dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
         print(f"Process group initialized for rank {rank}")

-    # Set the GPU device based on rank
     local_rank = rank % torch.cuda.device_count()
     torch.cuda.set_device(local_rank)
     print(f"Using GPU {local_rank} for rank {rank}")
+
+    # Return the device
+    return torch.device(f'cuda:{rank % torch.cuda.device_count()}')


 def datetime2sec(str):

@@ -187,7 +187,7 @@ def evaluate_on_EK100(eval_args,

     world_size = int(os.environ['WORLD_SIZE'])
     rank = int(os.environ['RANK'])
-    setup(rank, world_size)
+    device = setup(rank, world_size)


     if model is not None:

@@ -248,6 +248,7 @@ def collate_fn(batch):
         collate_fn=collate_fn,
         sampler = sampler,
         batch_size=1,
+        pin_memory = False,
         shuffle=False)

     # Set up logging

@@ -275,7 +276,6 @@ def collate_fn(batch):
     pretrained = eval_args.llava_checkpoint
     tokenizer, model, image_processor, _ = prepare_llava(pretrained)

-    device = torch.device(f'cuda:{rank}')

     global_avion_correct = torch.tensor(0.0, device=device)
     global_running_corrects = torch.tensor(0.0, device=device)
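Why the change matters: before this commit, evaluate_on_EK100 built its device as torch.device(f'cuda:{rank}'), treating the global rank as a GPU index. That only works on a single node; on any additional node the global rank exceeds the per-node GPU count and names a device that does not exist there. setup() now derives the device from rank % torch.cuda.device_count(), sets it, and returns it to the caller. Below is a minimal, self-contained sketch of that mapping, assuming a torchrun-style launcher that exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT; the setup_device helper is illustrative and not part of this repository.

import os

import torch
import torch.distributed as dist


def setup_device(rank: int, world_size: int) -> torch.device:
    """Join the NCCL process group and select the GPU local to this rank."""
    if not dist.is_initialized():
        # Assumes MASTER_ADDR / MASTER_PORT were exported by the launcher (e.g. torchrun).
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # Example: 2 nodes x 4 GPUs each gives global ranks 0..7, but every node only
    # exposes cuda:0..cuda:3. cuda:{rank} would request cuda:4..cuda:7 on the second
    # node and fail; the modulo keeps the index local. (torchrun assigns ranks
    # node-contiguously, so this matches the LOCAL_RANK it also exports.)
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    # Launched, for example, with: torchrun --nnodes=2 --nproc_per_node=4 sketch.py
    device = setup_device(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
    print(f"rank {os.environ['RANK']} -> {device}")
    dist.destroy_process_group()

Reading the LOCAL_RANK environment variable directly would yield the same index without the modulo; the commit's mapping and torchrun's convention agree as long as ranks are assigned contiguously per node.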
