Adding torch accelerator to ddp-tutorial-series example #1376
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12453"
Why was the port changed?
That was an error on my side; I've changed the port back.
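As background for the port discussion: MASTER_ADDR and MASTER_PORT define the rendezvous endpoint that init_process_group connects to, so every rank in the job must export identical values. A minimal, torch-free sketch (the ddp_env helper and its defaults are illustrative, not part of the PR):

```python
import os

def ddp_env(master_addr: str = "localhost", master_port: str = "12355"):
    """Export the rendezvous variables read by init_process_group.

    Every rank must export the *same* address and port; a silently
    changed port number would make ranks wait on different endpoints.
    """
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    return os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]

addr, port = ddp_env()
print(addr, port)
```

In the real examples, each spawned process runs this setup before calling init_process_group, which is why a stray edit to the port string is worth flagging in review.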
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
    torch.accelerator.set_device_idx(rank)
    device: torch.device = torch.device(f"{device_type}:{rank}")
Suggested change:
- device: torch.device = torch.device(f"{device_type}:{rank}")
+ device = torch.device(f"{device_type}:{rank}")
device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
There is no such API as device_index() in 2.7: https://docs.pytorch.org/docs/stable/accelerator.html. What is this line doing? You already set the index two lines above...
Ok, device_index() will appear only in 2.8: https://docs.pytorch.org/docs/main/generated/torch.accelerator.device_index.html#torch.accelerator.device_index. And it is a context manager, i.e. you need to use it as with device_index(). I don't see why you are using it here. The recently merged #1375 attempts to do the same; I think it will need a fix as well.
It does not make sense to call a context manager without with. Did you intend to call set_device_index() instead?
Yes, I'm making the changes, thanks.
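To illustrate the point raised above: a context manager only does its work inside a with block, so a bare device_index(rank) call merely constructs the manager object and discards it. A toy stand-in (not the real torch.accelerator implementation) showing the enter/exit semantics the 2.8 API is documented to have:

```python
from contextlib import contextmanager

# Toy stand-in for torch.accelerator.device_index (2.8+): the real API
# temporarily switches the active accelerator device; this just tracks
# an integer so the restore-on-exit behavior is visible.
_active = {"idx": 0}

@contextmanager
def device_index(idx: int):
    prev = _active["idx"]
    _active["idx"] = idx       # switch on entry
    try:
        yield
    finally:
        _active["idx"] = prev  # restore on exit

device_index(2)                # bare call: nothing happens
before = _active["idx"]        # still 0
with device_index(2):
    inside = _active["idx"]    # 2 only while the block is active
after = _active["idx"]         # restored to 0
print(before, inside, after)
```

This is why the reviewer suggests set_device_index(rank), which changes the active device persistently, instead of an unentered context manager.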
# torch.cuda.set_device(rank)
# init_process_group(backend="xccl", rank=rank, world_size=world_size)
Remove these commented-out lines:
- # torch.cuda.set_device(rank)
- # init_process_group(backend="xccl", rank=rank, world_size=world_size)
@@ -100,5 +115,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()

-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
+    print(world_size)
Remove this or convert it to a descriptive message:
- print(world_size)
device_type = torch.accelerator.current_accelerator()
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
I have a hard time understanding this code block; it does not make sense to me in multiple places. Why do you name the result of current_accelerator() as device_type if you return it from ddp_setup() the same way you return device on the CPU path? Does ddp_setup() return different kinds of values? Then something is happening with the rank which is also not quite clear. I think what you are trying to achieve is closer to this:
Suggested change:
- device_type = torch.accelerator.current_accelerator()
- device: torch.device = torch.device(f"{device_type}:{rank}")
- torch.accelerator.device_index(rank)
- print(f"Running on rank {rank} on device {device}")
+ torch.accelerator.set_device_index(rank)
+ device = torch.accelerator.current_accelerator()
+ print(f"Running on rank {rank} on device {device}")
Yes, so... there is a function in this file called _load_snapshot that loads the snapshot directly from the device it runs on, and in my first tests it was not finding the snapshot at all, so I changed it to device_type to get only the XPU variable.
Now I've tested again with only the device variable and it worked; sorry for the maze 🤓
I'm updating it with your suggestion, thanks.
print(f"Running on rank {rank} on device {device}")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend)
return device_type
And, corresponding to the suggestion above:
- return device_type
+ return device
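On the dynamic backend selection in this snippet: torch.distributed.get_default_backend_for_device essentially maps a device type to a collective backend. A plain-Python sketch of such a mapping (the exact table below is an assumption drawn from this review thread, e.g. nccl for CUDA and xccl for XPU, with gloo as the CPU fallback; it is not copied from torch source):

```python
# Illustrative device-type -> backend table mirroring what
# torch.distributed.get_default_backend_for_device resolves to.
# The entries are assumptions based on this review thread.
DEFAULT_BACKENDS = {"cuda": "nccl", "xpu": "xccl", "cpu": "gloo"}

def backend_for(device_type: str) -> str:
    # Fall back to gloo, the CPU backend, for unknown device types.
    return DEFAULT_BACKENDS.get(device_type, "gloo")

print(backend_for("cuda"), backend_for("xpu"), backend_for("cpu"))
```

This is what lets the updated ddp_setup work unchanged across CUDA, XPU, and CPU instead of hard-coding backend="nccl".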
init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
same comments as above
not addressed.
torch>=2.7
Add a newline at the end of the file.
# example.py

echo "Launching ${1:-example.py} with ${2:-2} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
Add a newline at the end of the file.
Signed-off-by: dggaytan <[email protected]>
Force-pushed from 2c0eb8f to 2ca1a5c
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12455"
It's still a different port number.
device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
It does not make sense to call a context manager without with. Did you intend to call set_device_index() instead?
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
    device = torch.device(f"{device_type}:{rank}")
    torch.accelerator.device_index(rank)
same as above
@@ -22,6 +34,7 @@ def __init__(
     optimizer: torch.optim.Optimizer,
     save_every: int,
     snapshot_path: str,
+    device
It would be nice to have a type annotation here:
- device
+ device: torch.device,
init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
not addressed.
Adding accelerator to ddp tutorials examples

Support for multiple accelerators:
- Updated the ddp_setup functions in multigpu.py, multigpu_torchrun.py, and multinode.py to use torch.accelerator for device management. The initialization of process groups now dynamically selects the backend based on the device type, with a fallback to CPU if no accelerator is available.
- Updated the Trainer classes in multigpu_torchrun.py and multinode.py to accept a device parameter and use it for model placement and snapshot loading.

Improvements to example execution:
- Added run_example.sh to simplify running tutorial examples with configurable GPU counts and node settings.
- Updated run_distributed_examples.sh to include a new function for running all DDP tutorial series examples.

Dependency updates:
- Updated the torch requirement in requirements.txt to 2.7 to ensure compatibility with the new torch.accelerator API.

CC: @msaroufim @malfet @dvrogozh