
Commit 3e2c3ae

Adding torch accelerator to ddp-tutorial-series example
Signed-off-by: dggaytan <[email protected]>
1 parent d47f0f3 commit 3e2c3ae

6 files changed: 65 additions & 14 deletions

distributed/ddp-tutorial-series/multigpu.py

Lines changed: 13 additions & 3 deletions
@@ -18,8 +18,18 @@ def ddp_setup(rank, world_size):
     """
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
-    torch.cuda.set_device(rank)
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
+        torch.accelerator.set_device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    init_process_group(backend=backend, rank=rank, world_size=world_size)
 
 class Trainer:
     def __init__(

@@ -100,5 +110,5 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
 
-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
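
All three training scripts in this commit share the same accelerator-agnostic device and backend selection. The following minimal sketch (not part of the commit) isolates that pattern; it assumes torch>=2.7, where torch.accelerator and torch.distributed.get_default_backend_for_device are available, and uses rank 0 as a stand-in for the per-process local rank.

import torch

# Stand-in for the per-process local rank (torchrun exports LOCAL_RANK).
rank = 0

if torch.accelerator.is_available():
    # current_accelerator() reports the active accelerator type, e.g. cuda or xpu.
    device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
else:
    device = torch.device("cpu")

# Picks a default process-group backend for the device type
# (for example nccl for cuda, gloo for cpu).
backend = torch.distributed.get_default_backend_for_device(device)
print(device, backend)

On a CUDA machine this would print something like "cuda:0 nccl"; on a CPU-only machine it falls back to "cpu gloo".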

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 18 additions & 5 deletions
@@ -11,8 +11,19 @@
 
 
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
+        torch.accelerator.set_device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    torch.distributed.init_process_group(backend=backend, device_id=device)
+    return device
+
 
 class Trainer:
     def __init__(

@@ -22,6 +33,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device: torch.device,
     ) -> None:
         self.gpu_id = int(os.environ["LOCAL_RANK"])
         self.model = model.to(self.gpu_id)

@@ -30,14 +42,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
 
         self.model = DDP(self.model, device_ids=[self.gpu_id])
 
     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.gpu_id}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]

@@ -92,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
 
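
The device returned by ddp_setup() is passed into Trainer so that snapshot loading maps tensors onto whichever accelerator is active. A small sketch of that map_location usage (not from the commit; the snapshot file and tiny model here are hypothetical):

import torch

device = torch.device("cpu")  # stand-in for the value ddp_setup() returns

# Hypothetical snapshot with the same keys the tutorial scripts use.
snapshot = {"MODEL_STATE": torch.nn.Linear(4, 2).state_dict(), "EPOCHS_RUN": 3}
torch.save(snapshot, "snapshot.pt")

# str(device) yields "cpu", "cuda:0", "xpu:0", ..., which torch.load accepts as a
# map_location, so resuming works on whichever device the process was assigned.
restored = torch.load("snapshot.pt", map_location=str(device))
print(restored["EPOCHS_RUN"])

Under torchrun, each worker process gets its own LOCAL_RANK, so str(self.device) resolves to a different accelerator index per process.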

distributed/ddp-tutorial-series/multinode.py

Lines changed: 17 additions & 5 deletions
@@ -11,8 +11,18 @@
 
 
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
+        torch.accelerator.set_device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    torch.distributed.init_process_group(backend=backend, device_id=device)
+    return device
 
 class Trainer:
     def __init__(

@@ -22,6 +32,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device: torch.device,
     ) -> None:
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])

@@ -31,14 +42,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
 
         self.model = DDP(self.model, device_ids=[self.local_rank])
 
     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.local_rank}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]

@@ -93,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
 
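
multinode.py additionally tracks a global rank next to the local one. A short sketch of that bookkeeping (not from the commit; the defaults are stand-ins so the snippet also runs outside torchrun, which normally sets these variables):

import os

local_rank = int(os.environ.get("LOCAL_RANK", 0))   # accelerator index on this node
global_rank = int(os.environ.get("RANK", 0))        # process index across all nodes

print(f"global rank {global_rank} -> local device index {local_rank}")

In the diff above, DDP still wraps the model with device_ids=[self.local_rank]; the global rank identifies the process across all nodes.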

distributed/ddp-tutorial-series/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torch>=1.11.0
+torch>=2.7
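
The bumped pin matters because the updated scripts call APIs that only exist in recent PyTorch releases. A quick sanity check, written as an assumption-laden sketch rather than part of the commit:

import torch

print(torch.__version__)
# torch.accelerator and torch.distributed.get_default_backend_for_device are
# only available in recent releases; the torch>=2.7 pin keeps the updated
# scripts from being run against an older install.
print(hasattr(torch, "accelerator") and torch.accelerator.is_available())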

distributed/ddp-tutorial-series/run_example.sh

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+# /bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}

run_distributed_examples.sh

Lines changed: 6 additions & 0 deletions

@@ -50,6 +50,12 @@ function distributed_tensor_parallelism() {
     uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
 }
 
+function distributed_ddp-tutorial-series() {
+    uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
+    uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
+    uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+}
+
 function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }
