|
12 | 12 |
|
13 | 13 | def ddp_setup():
|
14 | 14 | rank = int(os.environ["LOCAL_RANK"])
|
15 |
| - if torch.accelerator.is_available(): |
16 |
| - device_type = torch.accelerator.current_accelerator() |
17 |
| - device: torch.device = torch.device(f"{device_type}:{rank}") |
18 |
| - torch.accelerator.device_index(rank) |
19 |
| - print(f"Running on rank {rank} on device {device}") |
20 |
| - backend = torch.distributed.get_default_backend_for_device(device) |
21 |
| - torch.distributed.init_process_group(backend=backend) |
22 |
| - return device_type |
23 |
| - else: |
24 |
| - device = torch.device("cpu") |
25 |
| - print(f"Running on device {device}") |
26 |
| - torch.distributed.init_process_group(backend="gloo") |
27 |
| - return device |
| 15 | + |
| 16 | + device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") |
| 17 | + torch.accelerator.set_device_index(rank) |
| 18 | + print(f"Running on rank {rank} on device {device}") |
| 19 | + |
| 20 | + backend = torch.distributed.get_default_backend_for_device(rank) |
| 21 | + torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank) |
| 22 | + |
28 | 23 |
|
29 | 24 | class Trainer:
|
30 | 25 | def __init__(
|
@@ -52,7 +47,8 @@ def __init__(
|
52 | 47 | self.model = DDP(self.model, device_ids=[self.local_rank])
|
53 | 48 |
|
54 | 49 | def _load_snapshot(self, snapshot_path):
|
55 |
| - loc = str(self.device) |
| 50 | + loc = str(torch.accelerator.current_accelerator()) |
| 51 | + |
56 | 52 | snapshot = torch.load(snapshot_path, map_location=loc)
|
57 | 53 | self.model.load_state_dict(snapshot["MODEL_STATE"])
|
58 | 54 | self.epochs_run = snapshot["EPOCHS_RUN"]
|
@@ -118,8 +114,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
|
118 | 114 | if __name__ == "__main__":
|
119 | 115 | import argparse
|
120 | 116 | parser = argparse.ArgumentParser(description='simple distributed training job')
|
121 |
| - parser.add_argument('total_epochs', type=int, help='Total epochs to train the model') |
122 |
| - parser.add_argument('save_every', type=int, help='How often to save a snapshot') |
| 117 | + parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model') |
| 118 | + parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot') |
123 | 119 | parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
|
124 | 120 | args = parser.parse_args()
|
125 | 121 |
|
|
0 commit comments