11
11
12
12
13
13
def ddp_setup():
    """Initialize the distributed process group and select this rank's device.

    Reads ``LOCAL_RANK`` from the environment (set by ``torchrun``). When an
    accelerator (CUDA, XPU, ...) is available, binds this process to the
    accelerator index matching its local rank; otherwise falls back to CPU.

    Returns:
        torch.device: The device this rank should place its model and data on.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set (i.e. not launched via torchrun).
    """
    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        # Device string is accelerator-agnostic, e.g. "cuda:0" or "xpu:1".
        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
        torch.accelerator.set_device_index(rank)
        print(f"Running on rank {rank} on device {device}")
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")

    # Let PyTorch pick the backend matching the device (nccl for cuda, gloo for cpu, ...);
    # device_id lets init_process_group bind the group to this rank's device eagerly.
    backend = torch.distributed.get_default_backend_for_device(device)
    torch.distributed.init_process_group(backend=backend, device_id=device)
    return device
16
26
17
27
class Trainer :
18
28
def __init__ (
@@ -22,6 +32,7 @@ def __init__(
22
32
optimizer : torch .optim .Optimizer ,
23
33
save_every : int ,
24
34
snapshot_path : str ,
35
+ device : torch .device ,
25
36
) -> None :
26
37
self .local_rank = int (os .environ ["LOCAL_RANK" ])
27
38
self .global_rank = int (os .environ ["RANK" ])
@@ -31,14 +42,15 @@ def __init__(
31
42
self .save_every = save_every
32
43
self .epochs_run = 0
33
44
self .snapshot_path = snapshot_path
45
+ self .device = device
34
46
if os .path .exists (snapshot_path ):
35
47
print ("Loading snapshot" )
36
48
self ._load_snapshot (snapshot_path )
37
49
38
50
self .model = DDP (self .model , device_ids = [self .local_rank ])
39
51
40
52
def _load_snapshot (self , snapshot_path ):
41
- loc = f"cuda: { self .local_rank } "
53
+ loc = str ( self .device )
42
54
snapshot = torch .load (snapshot_path , map_location = loc )
43
55
self .model .load_state_dict (snapshot ["MODEL_STATE" ])
44
56
self .epochs_run = snapshot ["EPOCHS_RUN" ]
@@ -93,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
93
105
94
106
95
107
def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    """Entry point for one DDP worker process.

    Sets up the process group, builds the training objects, runs training,
    and tears the process group down.

    Args:
        save_every: Checkpoint the model every N epochs.
        total_epochs: Total number of epochs to train.
        batch_size: Per-rank batch size for the dataloader.
        snapshot_path: Where to read/write the training snapshot.
    """
    # ddp_setup now returns the device for this rank (accelerator or cpu),
    # which the Trainer needs for device-agnostic checkpoint loading.
    device = ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
    trainer.train(total_epochs)
    destroy_process_group()
102
114
0 commit comments