11
11
12
12
13
13
def ddp_setup():
    """Initialize torch.distributed for this worker and pick its device.

    Reads ``LOCAL_RANK`` (set by ``torchrun``) to select the per-process
    device index. Uses the accelerator-generic API when an accelerator is
    present, otherwise falls back to CPU with the gloo backend.

    Returns:
        torch.device: the fully-indexed device this process should use
        (e.g. ``cuda:0``), or ``cpu`` when no accelerator is available.
    """
    rank = int(os.environ["LOCAL_RANK"])
    # hasattr guard: torch.accelerator only exists in recent torch builds;
    # older versions fall through to the CPU/gloo path instead of crashing.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        device_type = torch.accelerator.current_accelerator()
        device: torch.device = torch.device(f"{device_type}:{rank}")
        # Bind this process to its accelerator before init_process_group.
        # (The original called torch.accelerator.device_index(rank) as a
        # bare statement — that is a context manager, so it never switched
        # the device; set_device_index is the actual setter.)
        torch.accelerator.set_device_index(rank)
        print(f"Running on rank {rank} on device {device}")
        backend = torch.distributed.get_default_backend_for_device(device)
        torch.distributed.init_process_group(backend=backend)
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")
        torch.distributed.init_process_group(backend="gloo")
    # Always return the indexed device: the original returned the
    # index-less `device_type` on the accelerator branch but `device`
    # on the CPU branch, giving callers an inconsistent value.
    return device
16
28
17
29
class Trainer :
18
30
def __init__ (
@@ -22,6 +34,7 @@ def __init__(
22
34
optimizer : torch .optim .Optimizer ,
23
35
save_every : int ,
24
36
snapshot_path : str ,
37
+ device
25
38
) -> None :
26
39
self .local_rank = int (os .environ ["LOCAL_RANK" ])
27
40
self .global_rank = int (os .environ ["RANK" ])
@@ -31,14 +44,15 @@ def __init__(
31
44
self .save_every = save_every
32
45
self .epochs_run = 0
33
46
self .snapshot_path = snapshot_path
47
+ self .device = device
34
48
if os .path .exists (snapshot_path ):
35
49
print ("Loading snapshot" )
36
50
self ._load_snapshot (snapshot_path )
37
51
38
52
self .model = DDP (self .model , device_ids = [self .local_rank ])
39
53
40
54
def _load_snapshot (self , snapshot_path ):
41
- loc = f"cuda: { self .local_rank } "
55
+ loc = str ( self .device )
42
56
snapshot = torch .load (snapshot_path , map_location = loc )
43
57
self .model .load_state_dict (snapshot ["MODEL_STATE" ])
44
58
self .epochs_run = snapshot ["EPOCHS_RUN" ]
@@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
93
107
94
108
95
109
def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    """Run one torchrun worker end to end.

    Sets up the process group and per-worker device, builds the training
    objects, trains for ``total_epochs``, then tears the group down.
    """
    worker_device = ddp_setup()
    train_set, net, opt = load_train_objs()
    loader = prepare_dataloader(train_set, batch_size)
    Trainer(net, loader, opt, save_every, snapshot_path, worker_device).train(total_epochs)
    destroy_process_group()
102
116
0 commit comments