@@ -17,24 +17,20 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12453"
+    os.environ["MASTER_PORT"] = "12455"


+    rank = int(os.environ["LOCAL_RANK"])
     if torch.accelerator.is_available():
         device_type = torch.accelerator.current_accelerator()
-        torch.accelerator.set_device_idx(rank)
-        device: torch.device = torch.device(f"{device_type}:{rank}")
+        device = torch.device(f"{device_type}:{rank}")
         torch.accelerator.device_index(rank)
         print(f"Running on rank {rank} on device {device}")
-        backend = torch.distributed.get_default_backend_for_device(device)
-        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
     else:
         device = torch.device("cpu")
         print(f"Running on device {device}")
-        torch.distributed.init_process_group(backend="gloo", device_id=device)

-    # torch.cuda.set_device(rank)
-    # init_process_group(backend="xccl", rank=rank, world_size=world_size)
+    backend = torch.distributed.get_default_backend_for_device(device)


 class Trainer:
     def __init__(
@@ -116,5 +112,4 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     args = parser.parse_args()

     world_size = torch.accelerator.device_count()
-    print(world_size)
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
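
For reference, below is a minimal, self-contained sketch of how the reworked setup path could be exercised end to end. Two assumptions are flagged: the hunk above computes backend via torch.distributed.get_default_backend_for_device(device) but no longer shows an init_process_group call, so the sketch assumes the process group is still initialized with that backend; and it uses the rank handed in by mp.spawn rather than the LOCAL_RANK environment variable (which a launcher such as torchrun would set). The names ddp_setup_sketch and _worker are hypothetical, not part of the patch.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_setup_sketch(rank: int, world_size: int) -> torch.device:
    # Rendezvous settings for a single-node run (values mirror the diff above).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12455"

    if torch.accelerator.is_available():
        # Resolve the current accelerator type (e.g. cuda) and bind this rank to one device.
        device_type = torch.accelerator.current_accelerator()
        device = torch.device(f"{device_type}:{rank}")
    else:
        device = torch.device("cpu")

    # Maps the device to its default distributed backend (e.g. "nccl" for CUDA, "gloo" for CPU).
    backend = dist.get_default_backend_for_device(device)
    # Assumption: the process group is still initialized with that backend, as the
    # pre-diff code did; the hunk itself no longer shows this call.
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return device


def _worker(rank: int, world_size: int) -> None:
    device = ddp_setup_sketch(rank, world_size)
    print(f"rank {rank} initialized on {device}")
    dist.destroy_process_group()


if __name__ == "__main__":
    # Spawn one process per visible accelerator; fall back to a single CPU process.
    world_size = max(torch.accelerator.device_count(), 1)
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)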