
Commit 6505c66

Refactor DDP example to use Accelerator API
Signed-off-by: jafraustro <[email protected]>
1 parent c8029c6 commit 6505c66

3 files changed: +9 additions, −7 deletions


distributed/ddp/README.md

Lines changed: 4 additions & 5 deletions
@@ -74,11 +74,10 @@ def spmd_main():
     local_world_size = int(env_dict['LOCAL_WORLD_SIZE'])
 
     print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
-    dist.init_process_group(backend="nccl")
-    print(
-        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
-        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
-    )
+    acc = torch.accelerator.current_accelerator()
+    vendor_backend = torch.distributed.get_default_backend_for_device(acc)
+    torch.accelerator.set_device_index(rank)
+    torch.distributed.init_process_group(backend=vendor_backend)
 
     demo_basic(rank)
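Taken out of diff form, the new initialization reads as below. This is a minimal sketch, assuming torch>=2.8 and a torchrun-style launcher that sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT; the init_distributed() helper name is illustrative, not part of the example.

    # Minimal sketch of the accelerator-based setup shown in the diff above.
    # Assumes torch>=2.8 and a torchrun-style launcher that sets RANK,
    # WORLD_SIZE, MASTER_ADDR and MASTER_PORT. init_distributed() is an
    # illustrative helper name, not part of the example itself.
    import os

    import torch
    import torch.distributed as dist


    def init_distributed() -> int:
        rank = int(os.environ["RANK"])  # provided by the launcher
        # Ask PyTorch which accelerator is present (cuda, xpu, ...) and which
        # communication backend it defaults to, instead of hard-coding "nccl".
        acc = torch.accelerator.current_accelerator()
        backend = dist.get_default_backend_for_device(acc)
        # Bind this process to its device before creating the process group;
        # the diff uses the global rank, which matches single-node runs.
        torch.accelerator.set_device_index(rank)
        dist.init_process_group(backend=backend)
        print(
            f"[{os.getpid()}] rank = {dist.get_rank()}, "
            f"world_size = {dist.get_world_size()}, backend = {dist.get_backend()}"
        )
        return rank


    if __name__ == "__main__":
        init_distributed()
        # ... training would go here ...
        dist.destroy_process_group()

For CUDA devices get_default_backend_for_device resolves to nccl, so behaviour matches the previous hard-coded call, while other accelerators get their own default backend.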

distributed/ddp/example.py

Lines changed: 4 additions & 1 deletion
@@ -72,7 +72,10 @@ def main():
         dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]), world_size=int(env_dict["WORLD_SIZE"]))
     else:
         print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
-        dist.init_process_group(backend="nccl")
+        acc = torch.accelerator.current_accelerator()
+        backend = torch.distributed.get_default_backend_for_device(acc)
+        torch.accelerator.set_device_index(rank)
+        dist.init_process_group(backend=backend)
 
     print(
         f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "

distributed/ddp/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torch
+torch>=2.8
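The pin presumably tracks the torch.accelerator and get_default_backend_for_device calls used above; a quick, hypothetical check that an installed environment exposes them:

    # Hypothetical environment check: confirm the installed torch exposes the
    # APIs the refactored example now calls (requirements.txt pins torch>=2.8).
    import torch
    import torch.distributed as dist

    print(torch.__version__)
    assert hasattr(torch, "accelerator"), "torch.accelerator API missing"
    assert hasattr(dist, "get_default_backend_for_device"), (
        "torch.distributed.get_default_backend_for_device missing"
    )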
