@@ -30,19 +30,19 @@ def process_raw_pred(raw_pred):
     return raw_pred

 def setup(rank, world_size):
-    # Check if the process group is already initialized
     if not dist.is_initialized():
-        # Initialize the process group if it hasn't been initialized yet
-        os.environ['MASTER_ADDR'] = '127.0.0.1'  # Replace with master node IP
-        os.environ['MASTER_PORT'] = '29500'  # Set a port for communication
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'

         dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
         print(f"Process group initialized for rank {rank}")

-    # Set the GPU device based on rank
     local_rank = rank % torch.cuda.device_count()
     torch.cuda.set_device(local_rank)
     print(f"Using GPU {local_rank} for rank {rank}")
+
+    # Return the device
+    return torch.device(f'cuda:{rank % torch.cuda.device_count()}')


 def datetime2sec(str):
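For context, here is a minimal, self-contained sketch of how the revised `setup()` is expected to be used: it initializes the NCCL process group if needed, pins each rank to a local GPU, and returns the matching device, which the caller keeps via `device = setup(rank, world_size)` as in the hunk further down. The `torchrun` launch, the toy tensor, and the `all_reduce` call are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (assumes RANK/WORLD_SIZE are exported by a launcher such as
# `torchrun --nproc_per_node=<gpus> script.py`; not part of this commit).
import os
import torch
import torch.distributed as dist

def setup(rank, world_size):
    # Mirrors the setup() in the diff above.
    if not dist.is_initialized():
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    local_rank = rank % torch.cuda.device_count()  # map global rank to a GPU on this node
    torch.cuda.set_device(local_rank)
    return torch.device(f'cuda:{local_rank}')

if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = setup(rank, world_size)      # same call pattern as in this commit
    x = torch.ones(1, device=device)      # tensors go straight onto the selected GPU
    dist.all_reduce(x)                    # sums across ranks; x becomes world_size
    dist.destroy_process_group()
```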
@@ -187,7 +187,7 @@ def evaluate_on_EK100(eval_args,

     world_size = int(os.environ['WORLD_SIZE'])
     rank = int(os.environ['RANK'])
-    setup(rank, world_size)
+    device = setup(rank, world_size)


     if model is not None:
@@ -248,6 +248,7 @@ def collate_fn(batch):
         collate_fn=collate_fn,
         sampler=sampler,
         batch_size=1,
+        pin_memory=False,
         shuffle=False)

     # Set up logging
@@ -275,7 +276,6 @@ def collate_fn(batch):
     pretrained = eval_args.llava_checkpoint
     tokenizer, model, image_processor, _ = prepare_llava(pretrained)

-    device = torch.device(f'cuda:{rank}')

     global_avion_correct = torch.tensor(0.0, device=device)
     global_running_corrects = torch.tensor(0.0, device=device)
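The removed `device = torch.device(f'cuda:{rank}')` would point at a non-existent GPU whenever the global rank exceeds the per-node GPU count (e.g., multi-node runs), while the device returned by `setup()` wraps the rank back into the local GPU range. A toy illustration, where the 8-GPUs-per-node figure is an assumption:

```python
# Illustrative only: compare the removed cuda:{rank} with the device setup()
# now returns, assuming a hypothetical 8 GPUs per node.
gpus_per_node = 8
for rank in (0, 7, 8, 15):                    # ranks spanning two nodes
    old = f"cuda:{rank}"                      # removed line: invalid once rank >= 8
    new = f"cuda:{rank % gpus_per_node}"      # what setup() returns for this rank
    print(f"rank {rank}: {old} -> {new}")
```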