Commit 2c0eb8f

Adding torch accelerator to ddp-tutorial-series example

Signed-off-by: dggaytan <[email protected]>

1 parent cb48338, commit 2c0eb8f

File tree: 4 files changed, +11 -17 lines


distributed/ddp-tutorial-series/multigpu.py

Lines changed: 4 additions & 9 deletions
@@ -17,24 +17,20 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12453"
+    os.environ["MASTER_PORT"] = "12455"


+    rank = int(os.environ["LOCAL_RANK"])
     if torch.accelerator.is_available():
         device_type = torch.accelerator.current_accelerator()
-        torch.accelerator.set_device_idx(rank)
-        device: torch.device = torch.device(f"{device_type}:{rank}")
+        device = torch.device(f"{device_type}:{rank}")
         torch.accelerator.device_index(rank)
         print(f"Running on rank {rank} on device {device}")
-        backend = torch.distributed.get_default_backend_for_device(device)
-        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
     else:
         device = torch.device("cpu")
         print(f"Running on device {device}")
-        torch.distributed.init_process_group(backend="gloo", device_id=device)

-    # torch.cuda.set_device(rank)
-    # init_process_group(backend="xccl", rank=rank, world_size=world_size)
+    backend = torch.distributed.get_default_backend_for_device(device)

 class Trainer:
     def __init__(

@@ -116,5 +112,4 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     args = parser.parse_args()

     world_size = torch.accelerator.device_count()
-    print(world_size)
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
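Read together, the new hunk selects the device through the torch.accelerator API and derives the process-group backend from that device instead of hard-coding "nccl" or "gloo". Below is a minimal sketch of that pattern for the spawn-based variant, assuming the init_process_group call (not visible in this hunk) still receives the derived backend and device, as the deleted lines did; it is illustrative, not the tutorial file verbatim.

import os
import torch
from torch.distributed import init_process_group


def ddp_setup(rank: int, world_size: int):
    # Rendezvous information for the default TCP store.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12455"

    if torch.accelerator.is_available():
        # current_accelerator() reports the active accelerator type (e.g. cuda, xpu).
        device_type = torch.accelerator.current_accelerator()
        device = torch.device(f"{device_type}:{rank}")
    else:
        device = torch.device("cpu")

    # Ask torch.distributed which backend matches this device (nccl, xccl, gloo, ...).
    backend = torch.distributed.get_default_backend_for_device(device)
    init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
    return device

With mp.spawn each worker receives its rank as the first argument, so the rank doubles as the local device index here.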

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 5 additions & 6 deletions
@@ -14,17 +14,16 @@ def ddp_setup():
     rank = int(os.environ["LOCAL_RANK"])
     if torch.accelerator.is_available():
         device_type = torch.accelerator.current_accelerator()
-        device: torch.device = torch.device(f"{device_type}:{rank}")
+        device = torch.device(f"{device_type}:{rank}")
         torch.accelerator.device_index(rank)
         print(f"Running on rank {rank} on device {device}")
-        backend = torch.distributed.get_default_backend_for_device(device)
-        torch.distributed.init_process_group(backend=backend)
-        return device_type
     else:
         device = torch.device("cpu")
         print(f"Running on device {device}")
-        torch.distributed.init_process_group(backend="gloo")
-        return device
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    torch.distributed.init_process_group(backend=backend, device_id=device)
+    return device


 class Trainer:
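Because the torchrun variant now returns one device object on both paths, a caller can use it uniformly. A hypothetical consumer is sketched below; the model and tensor shapes are placeholders, not the tutorial's Trainer, and ddp_setup() is the function from the diff above.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # ddp_setup() comes from the diff above; torchrun supplies LOCAL_RANK.
    device = ddp_setup()

    model = torch.nn.Linear(20, 10).to(device)  # placeholder model
    # DDP takes no device_ids for CPU/gloo processes; otherwise pass the single local device index.
    ddp_model = DDP(model, device_ids=None if device.type == "cpu" else [device.index])

    out = ddp_model(torch.randn(4, 20, device=device))  # smoke-test forward pass
    print(out.shape)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with torchrun, each process gets LOCAL_RANK, RANK and WORLD_SIZE from the launcher, which is why ddp_setup() no longer needs explicit rank and world_size arguments.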
distributed/ddp-tutorial-series/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torch>=2.7
+torch>=2.7
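The torch.accelerator calls used in the files above only exist in recent PyTorch releases, which the torch>=2.7 pin guarantees. A purely illustrative environment check:

import torch

# requirements.txt pins torch>=2.7; torch.accelerator needs a recent release.
print("torch version:", torch.__version__)
print("accelerator available:", torch.accelerator.is_available())
if torch.accelerator.is_available():
    print("current accelerator:", torch.accelerator.current_accelerator())
print("visible accelerator devices:", torch.accelerator.device_count())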

distributed/ddp-tutorial-series/run_example.sh

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@
 # example.py

 echo "Launching ${1:-example.py} with ${2:-2} gpus"
-torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
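As before, the script's first argument selects the training script and the second the per-node process count, so an invocation like `bash run_example.sh multigpu_torchrun.py 2` would presumably start two torchrun workers against the rendezvous endpoint above; with no arguments it falls back to example.py with 2 processes.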
