Adding torch accelerator to ddp-tutorial-series example #1376


Open · wants to merge 1 commit into main

Conversation

dggaytan
Contributor

Adding accelerator support to the DDP tutorial examples

Support for multiple accelerators:

  • Updated ddp_setup functions in multigpu.py, multigpu_torchrun.py, and multinode.py to use torch.accelerator for device management. The initialization of process groups now dynamically selects the backend based on the device type, with a fallback to CPU if no accelerator is available (a minimal sketch follows this list).
  • Modified Trainer classes in multigpu_torchrun.py and multinode.py to accept a device parameter and use it for model placement and snapshot loading.
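
A minimal sketch of the updated ddp_setup flow described above (illustration only; it assembles the pieces shown in the review below and assumes the PyTorch 2.7 torch.accelerator API):

import os
import torch
from torch.distributed import init_process_group

def ddp_setup(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"  # illustrative port
    if torch.accelerator.is_available():
        # Bind this process to its accelerator and pick the matching
        # backend (e.g. nccl for CUDA, xccl for XPU).
        torch.accelerator.set_device_index(rank)
        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
        backend = torch.distributed.get_default_backend_for_device(device)
    else:
        # CPU fallback
        device = torch.device("cpu")
        backend = "gloo"
    init_process_group(backend=backend, rank=rank, world_size=world_size)
    return device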

Improvements to example execution:

  • Added run_example.sh to simplify running tutorial examples with configurable GPU counts and node settings (an example invocation follows this list).
  • Updated run_distributed_examples.sh to include a new function for running all DDP tutorial series examples.
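
For example, based on the run_example.sh snippet quoted later in this review, an invocation could look like this (script name and GPU count are placeholders):

bash run_example.sh multigpu_torchrun.py 4   # run multigpu_torchrun.py with 4 GPUs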

Dependency updates:

  • Increased the minimum PyTorch version requirement in requirements.txt to 2.7 to ensure compatibility with the new torch.accelerator API.

CC: @msaroufim @malfet @dvrogozh


netlify bot commented Jul 21, 2025

Deploy Preview for pytorch-examples-preview canceled.

🔨 Latest commit: 2ca1a5c
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-examples-preview/deploys/6893c439faa5d90008174f85

@meta-cla meta-cla bot added the cla signed label Jul 21, 2025
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12453"
Contributor

Why was the port changed?

Contributor Author
@dggaytan commented on Aug 6, 2025

It was an error on my side; I've changed the port.

if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
Contributor

Suggested change
device: torch.device = torch.device(f"{device_type}:{rank}")
device = torch.device(f"{device_type}:{rank}")

device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
Contributor

There is no such API device_index() in 2.7: https://docs.pytorch.org/docs/stable/accelerator.html

What is it doing? You already set the index two lines above...

Contributor

OK, device_index() will only appear in 2.8: https://docs.pytorch.org/docs/main/generated/torch.accelerator.device_index.html#torch.accelerator.device_index. And it is a context manager, i.e. you need to use it as with device_index(). I don't see why you are using it here. The recently merged #1375 attempts to do the same; I think it will need a fix as well.
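
For reference, a minimal sketch of the two usages discussed here (based on the linked torch.accelerator docs; the rank value is just illustrative):

import torch

rank = 0  # illustrative local rank

if torch.accelerator.is_available():
    # PyTorch >= 2.7: set the current device index for this process.
    torch.accelerator.set_device_index(rank)

    # PyTorch >= 2.8: device_index() is a context manager, so it is only
    # meaningful inside a `with` block that temporarily switches the device.
    with torch.accelerator.device_index(rank):
        x = torch.ones(1, device=torch.accelerator.current_accelerator())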

Contributor

It does not make sense to call a context manager without with. Did you intend to call set_device_index() instead?

Contributor Author

yes, I'm making the changes, thanks

Comment on lines 35 to 37

# torch.cuda.set_device(rank)
# init_process_group(backend="xccl", rank=rank, world_size=world_size)
Contributor

Remove comments:

Suggested change
# torch.cuda.set_device(rank)
# init_process_group(backend="xccl", rank=rank, world_size=world_size)

@@ -100,5 +115,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

world_size = torch.cuda.device_count()
world_size = torch.accelerator.device_count()
print(world_size)
Contributor

Remove or convert to descriptive message:

Suggested change
print(world_size)
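
A descriptive variant (the exact wording is just a suggestion) could be:

print(f"Using {world_size} devices for training")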

Comment on lines 16 to 19
device_type = torch.accelerator.current_accelerator()
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
Contributor

I have a hard time understanding this code block; it does not make sense to me in multiple places. Why do you name the result of current_accelerator() device_type if you return it from ddp_setup() the same way you return device on the CPU path? Does ddp_setup() return different kinds of values? Next, something is happening with the rank, which is also not quite clear.

I think what you are trying to achieve is closer to this:

Suggested change
device_type = torch.accelerator.current_accelerator()
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
torch.accelerator.set_device_index(rank)
device = torch.accelerator.current_accelerator()
print(f"Running on rank {rank} on device {device}")

Contributor Author

Yes, so... there is a function in this file called _load_snapshot that loads the snapshot directly onto the device it is running on, and in my first tests it was not loading the snapshot at all, so I changed it to device_type so that only the XPU device type was used.

Now I've tested again with only the "device" variable and it worked; sorry for the maze 🤓

I'm updating it with your suggestion, thanks

print(f"Running on rank {rank} on device {device}")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend)
return device_type
Contributor

and, corresponding to the above:

Suggested change
return device_type
return device

init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
Contributor

same comments as above

Contributor

not addressed.

torch>=2.7
Contributor

Add a newline at the end of the file.

# example.py

echo "Launching ${1:-example.py} with ${2:-2} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
Contributor

Add a newline at the end of the file.

@dggaytan force-pushed the dggaytan/distributed_DDP branch from 2c0eb8f to 2ca1a5c on August 6, 2025 21:08
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12455"
Contributor

It's still a different port number.

device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
Contributor

It does not make sense to call a context manager without with. Did you intend to call set_device_index() instead?

if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
Contributor

same as above

@@ -22,6 +34,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
save_every: int,
snapshot_path: str,
device
Contributor

It would be nice to have a type annotation here:

Suggested change
device
device: torch.device,

init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
Contributor

not addressed.
