Skip to content

Commit cc9f51e

Browse files
authored
Adding torch accelerator to ddp-tutorial-series example
* Adding torch accelerator to ddp-tutorial-series example Signed-off-by: dggaytan <[email protected]> * Adding torch accelerator to ddp-tutorial-series example Signed-off-by: dggaytan <[email protected]> * Adding torch accelerator to ddp-tutorial-series example Signed-off-by: dggaytan <[email protected]> --------- Signed-off-by: dggaytan <[email protected]>
1 parent 152ca30 commit cc9f51e

File tree

7 files changed

+68
-54
lines changed

7 files changed

+68
-54
lines changed

distributed/ddp-tutorial-series/README.md

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,27 @@ Each code file extends upon the previous one. The series starts with a non-distr
1515
* [slurm/setup_pcluster_slurm.md](slurm/setup_pcluster_slurm.md): instructions to set up an AWS cluster
1616
* [slurm/config.yaml.template](slurm/config.yaml.template): configuration to set up an AWS cluster
1717
* [slurm/sbatch_run.sh](slurm/sbatch_run.sh): slurm script to launch the training job
18-
19-
20-
21-
18+
## Installation
19+
```
20+
pip install -r requirements.txt
21+
```
22+
## Running Examples
23+
For running the examples to run for 20 Epochs and save checkpoints every 5 Epochs, you can use the following command:
24+
### Single GPU
25+
```
26+
python single_gpu.py 20 5
27+
```
28+
### Multi-GPU
29+
```
30+
python multigpu.py 20 5
31+
```
32+
### Multi-GPU Torchrun
33+
```
34+
torchrun --nnodes=1 --nproc_per_node=4 multigpu_torchrun.py 20 5
35+
```
36+
### Multi-Node
37+
```
38+
torchrun --nnodes=2 --nproc_per_node=4 multinode.py 20 5
39+
```
40+
41+
For more details, check the [run_examples.sh](distributed/ddp-tutorial-series/run_examples.sh) script.

distributed/ddp-tutorial-series/multigpu.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,15 @@ def ddp_setup(rank, world_size):
1717
world_size: Total number of processes
1818
"""
1919
os.environ["MASTER_ADDR"] = "localhost"
20-
os.environ["MASTER_PORT"] = "12455"
20+
os.environ["MASTER_PORT"] = "12355"
2121

22+
device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
23+
torch.accelerator.set_device_index(rank)
24+
print(f"Running on rank {rank} on device {device}")
2225

23-
rank = int(os.environ["LOCAL_RANK"])
24-
if torch.accelerator.is_available():
25-
device_type = torch.accelerator.current_accelerator()
26-
device = torch.device(f"{device_type}:{rank}")
27-
torch.accelerator.device_index(rank)
28-
print(f"Running on rank {rank} on device {device}")
29-
else:
30-
device = torch.device("cpu")
31-
print(f"Running on device {device}")
32-
3326
backend = torch.distributed.get_default_backend_for_device(device)
27+
init_process_group(backend=backend, rank=rank, world_size=world_size)
28+
3429

3530
class Trainer:
3631
def __init__(
@@ -106,8 +101,8 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
106101
if __name__ == "__main__":
107102
import argparse
108103
parser = argparse.ArgumentParser(description='simple distributed training job')
109-
parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
110-
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
104+
parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
105+
parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
111106
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
112107
args = parser.parse_args()
113108

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,14 @@
1212

1313
def ddp_setup():
1414
rank = int(os.environ["LOCAL_RANK"])
15-
if torch.accelerator.is_available():
16-
device_type = torch.accelerator.current_accelerator()
17-
device = torch.device(f"{device_type}:{rank}")
18-
torch.accelerator.device_index(rank)
19-
print(f"Running on rank {rank} on device {device}")
20-
else:
21-
device = torch.device("cpu")
22-
print(f"Running on device {device}")
23-
24-
backend = torch.distributed.get_default_backend_for_device(device)
25-
torch.distributed.init_process_group(backend=backend, device_id=device)
26-
return device
2715

16+
device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
17+
torch.accelerator.set_device_index(rank)
18+
print(f"Running on rank {rank} on device {device}")
19+
20+
backend = torch.distributed.get_default_backend_for_device(rank)
21+
torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank)
22+
2823

2924
class Trainer:
3025
def __init__(
@@ -51,7 +46,9 @@ def __init__(
5146
self.model = DDP(self.model, device_ids=[self.gpu_id])
5247

5348
def _load_snapshot(self, snapshot_path):
54-
loc = str(self.device)
49+
50+
loc = str(torch.accelerator.current_accelerator())
51+
5552
snapshot = torch.load(snapshot_path, map_location=loc)
5653
self.model.load_state_dict(snapshot["MODEL_STATE"])
5754
self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -117,8 +114,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
117114
if __name__ == "__main__":
118115
import argparse
119116
parser = argparse.ArgumentParser(description='simple distributed training job')
120-
parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
121-
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
117+
parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
118+
parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
122119
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
123120
args = parser.parse_args()
124121

distributed/ddp-tutorial-series/multinode.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,14 @@
1212

1313
def ddp_setup():
1414
rank = int(os.environ["LOCAL_RANK"])
15-
if torch.accelerator.is_available():
16-
device_type = torch.accelerator.current_accelerator()
17-
device: torch.device = torch.device(f"{device_type}:{rank}")
18-
torch.accelerator.device_index(rank)
19-
print(f"Running on rank {rank} on device {device}")
20-
backend = torch.distributed.get_default_backend_for_device(device)
21-
torch.distributed.init_process_group(backend=backend)
22-
return device_type
23-
else:
24-
device = torch.device("cpu")
25-
print(f"Running on device {device}")
26-
torch.distributed.init_process_group(backend="gloo")
27-
return device
15+
16+
device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
17+
torch.accelerator.set_device_index(rank)
18+
print(f"Running on rank {rank} on device {device}")
19+
20+
backend = torch.distributed.get_default_backend_for_device(rank)
21+
torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank)
22+
2823

2924
class Trainer:
3025
def __init__(
@@ -52,7 +47,8 @@ def __init__(
5247
self.model = DDP(self.model, device_ids=[self.local_rank])
5348

5449
def _load_snapshot(self, snapshot_path):
55-
loc = str(self.device)
50+
loc = str(torch.accelerator.current_accelerator())
51+
5652
snapshot = torch.load(snapshot_path, map_location=loc)
5753
self.model.load_state_dict(snapshot["MODEL_STATE"])
5854
self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -118,8 +114,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
118114
if __name__ == "__main__":
119115
import argparse
120116
parser = argparse.ArgumentParser(description='simple distributed training job')
121-
parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
122-
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
117+
parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
118+
parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
123119
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
124120
args = parser.parse_args()
125121

distributed/ddp-tutorial-series/run_example.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# num_gpus = num local gpus to use (must be at least 2). Default = 2
55

66
# samples to run include:
7-
# example.py
7+
8+
# multigpu_torchrun.py
9+
# multinode.py
810

911
echo "Launching ${1:-example.py} with ${2:-2} gpus"
10-
torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
12+
torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py} 10 1
13+

distributed/ddp-tutorial-series/single_gpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def main(device, total_epochs, save_every, batch_size):
7373
if __name__ == "__main__":
7474
import argparse
7575
parser = argparse.ArgumentParser(description='simple distributed training job')
76-
parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
77-
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
76+
parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
77+
parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
7878
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
7979
args = parser.parse_args()
8080

81-
device = 0 # shorthand for cuda:0
81+
device = 0
8282
main(device, args.total_epochs, args.save_every, args.batch_size)

run_distributed_examples.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,16 @@ function distributed_tensor_parallelism() {
5151
}
5252

5353
function distributed_ddp-tutorial-series() {
54-
uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
54+
uv python multigpu.py 10 1 || error "ddp tutorial series multigpu example failed"
5555
uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
5656
uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
57+
uv python single_gpu.py 10 1 || error "ddp tutorial series single gpu example failed"
58+
5759
}
5860

5961
function distributed_FSDP2() {
6062
uv run bash run_example.sh example.py || error "FSDP2 example failed"
63+
6164
}
6265

6366
function distributed_ddp() {

0 commit comments

Comments
 (0)