Commit cb48338

Adding torch accelerator to ddp-tutorial-series example
Signed-off-by: dggaytan <[email protected]>
1 parent d47f0f3 commit cb48338

6 files changed: +76 additions, -15 deletions

distributed/ddp-tutorial-series/multigpu.py

Lines changed: 20 additions & 4 deletions
@@ -17,9 +17,24 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
-    torch.cuda.set_device(rank)
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    os.environ["MASTER_PORT"] = "12453"
+
+
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        torch.accelerator.set_device_idx(rank)
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo", device_id=device)
+
+    # torch.cuda.set_device(rank)
+    # init_process_group(backend="xccl", rank=rank, world_size=world_size)

 class Trainer:
     def __init__(
@@ -100,5 +115,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()

-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
+    print(world_size)
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
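Note: the revised ddp_setup leans on the torch.accelerator API and on torch.distributed choosing a backend per device. The following is a minimal standalone sketch (not part of this commit, assuming torch>=2.7 as in the updated requirements) that probes the same calls to show what the setup would select on a given machine:

# Standalone probe of the accelerator APIs used in the diff above (assumes torch>=2.7).
import torch
import torch.distributed as dist

if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()   # e.g. device("cuda") or device("xpu")
    count = torch.accelerator.device_count()        # number of local accelerator devices
    backend = dist.get_default_backend_for_device(torch.device(f"{acc}:0"))
    print(f"accelerator={acc}, devices={count}, default backend={backend}")
else:
    # Matches the fallback branch above: CPU training with the gloo backend.
    print("no accelerator available; would fall back to CPU + gloo")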

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 20 additions & 5 deletions
@@ -11,8 +11,21 @@


 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend)
+        return device_type
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo")
+        return device
+

 class Trainer:
     def __init__(
@@ -22,6 +35,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device
     ) -> None:
         self.gpu_id = int(os.environ["LOCAL_RANK"])
         self.model = model.to(self.gpu_id)
@@ -30,14 +44,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)

         self.model = DDP(self.model, device_ids=[self.gpu_id])

     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.gpu_id}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -92,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):


 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
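Side note: the Trainer now builds map_location from the device returned by ddp_setup() instead of hard-coding f"cuda:{rank}". A minimal sketch of that checkpoint pattern outside the Trainer (the file name snapshot_demo.pt is hypothetical; CPU keeps the sketch runnable anywhere, and the MODEL_STATE / EPOCHS_RUN keys match the ones used above):

# Illustrative round-trip only; snapshot_demo.pt is a made-up path.
import torch

model = torch.nn.Linear(4, 2)
torch.save({"MODEL_STATE": model.state_dict(), "EPOCHS_RUN": 3}, "snapshot_demo.pt")

# In the Trainer, `device` comes from ddp_setup(); here we pin it to CPU.
device = torch.device("cpu")
snapshot = torch.load("snapshot_demo.pt", map_location=str(device))
model.load_state_dict(snapshot["MODEL_STATE"])
print("resuming from epoch", snapshot["EPOCHS_RUN"])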

distributed/ddp-tutorial-series/multinode.py

Lines changed: 19 additions & 5 deletions
@@ -11,8 +11,20 @@


 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend)
+        return device_type
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo")
+        return device

 class Trainer:
     def __init__(
@@ -22,6 +34,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device
     ) -> None:
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
@@ -31,14 +44,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)

         self.model = DDP(self.model, device_ids=[self.local_rank])

     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.local_rank}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):


 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
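For orientation: in the multinode variant the Trainer reads LOCAL_RANK (device index on one node) and RANK (global process id across nodes) separately; both are exported by torchrun. A minimal sketch (not from the commit) of combining them with the accelerator device, with .get defaults only so it also runs outside torchrun:

# Reads the torchrun-provided rank variables; defaults are for running the sketch standalone.
import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))   # device index on this node
global_rank = int(os.environ.get("RANK", "0"))        # unique process id across all nodes
world_size = int(os.environ.get("WORLD_SIZE", "1"))   # total number of processes

if torch.accelerator.is_available():
    device = torch.device(f"{torch.accelerator.current_accelerator()}:{local_rank}")
else:
    device = torch.device("cpu")

print(f"[rank {global_rank}/{world_size}] local rank {local_rank} -> {device}")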

distributed/ddp-tutorial-series/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torch>=1.11.0
+torch>=2.7

distributed/ddp-tutorial-series/run_example.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# /bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}

run_distributed_examples.sh

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ function distributed_tensor_parallelism() {
     uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
 }

+function distributed_ddp-tutorial-series() {
+    uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
+    uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
+    uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+}
+
 function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }
