TP, SP, and FSDP examples device generalization (CUDA, XPU, etc.) #1354
Changes from 5 commits
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
Reviewer: This seems to be a user-specific comment. It actually might be helpful at the beginning of the script, but it needs more wording prefacing that this is the command, and describe what's…

Author: Agree, a known command significantly speeds up running. torchrun is also relatively platform independent, hence it can pretty much run out of the box. Will change it to something like the following: "The following is an example command to run this code. torchrun --nnodes 1 --nproc-per-node 4"
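A minimal sketch of how the expanded header comment could read once the author's rewording lands; the `<fn>` placeholder is kept from the original diff rather than filled in with a specific script name:

```python
# The following is an example command to run this code on a single node with
# four devices. torchrun is largely platform independent, so the same command
# should work for CUDA, XPU, and similar backends:
#
#   torchrun --nnodes 1 --nproc-per-node 4 <fn>
```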
import os
import sys
import torch
@@ -76,8 +77,8 @@ def forward(self, x):
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+device_type = torch.accelerator.current_accelerator().type
+device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
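For readers following the device-generalization change, here is a small self-contained sketch of the pattern the new lines rely on. It assumes a recent PyTorch build that ships `torch.accelerator`, and adds a CPU fallback (not present in the diff) for machines without an accelerator:

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

# Resolve the active accelerator backend ("cuda", "xpu", ...) instead of
# hard-coding "cuda"; fall back to CPU when no accelerator is available.
accelerator = torch.accelerator.current_accelerator()
device_type = accelerator.type if accelerator is not None else "cpu"

# One-dimensional mesh spanning every rank launched by torchrun.
_world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
```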
@@ -88,8 +89,8 @@ def forward(self, x):
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
-tp_model = ToyModel().to("cuda")
+# create model and move it to GPU - initdevice_type_mesh has already mapped GPU ids.
githubsgi marked this conversation as resolved.
+tp_model = ToyModel().to(device_type)
# Custom parallelization plan for the model
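The parallelization plan itself falls outside the hunks shown here. Purely as a hedged illustration of how the device-generalized model feeds into it, the sketch below reuses `device_type` and `device_mesh` from the mesh-creation sketch above and uses a stand-in ToyModel with assumed `in_proj`/`out_proj` layer names and sizes; the real example may differ in those details:

```python
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyModel(nn.Module):
    """Stand-in for the example's model; layer names and sizes are assumptions."""
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))

# Move the model to the resolved device type rather than a hard-coded "cuda",
# then shard the two linear layers across the 1-D device mesh.
tp_model = ToyModel().to(device_type)
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
```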
@@ -116,7 +117,7 @@ def forward(self, x):
# For TP, input needs to be same across all TP ranks.
# Setting the random seed is to mimic the behavior of dataloader.
torch.manual_seed(i)
-inp = torch.rand(20, 10, device="cuda")
+inp = torch.rand(20, 10, device=device_type)
output = tp_model(inp)
output.sum().backward()
optimizer.step()
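The FSDP examples named in the PR title are not part of this excerpt. Purely as a hedged sketch of how the same `device_type` substitution might carry over there, assuming the standard FullyShardedDataParallel wrapper, a torchrun launch (which sets LOCAL_RANK), and an accelerator being present; this is not the PR's actual FSDP change:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Recent PyTorch versions can infer a default backend when none is given.
dist.init_process_group()

# Same device-type resolution as in the TP example above
# (assumes an accelerator is available).
device_type = torch.accelerator.current_accelerator().type
local_rank = int(os.environ["LOCAL_RANK"])

# Stand-in module; the real examples use their own model classes.
model = torch.nn.Linear(10, 10)

# Hand FSDP an explicit per-rank device instead of assuming CUDA.
fsdp_model = FSDP(model, device_id=torch.device(device_type, local_rank))
```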
@@ -1 +1 @@
-3.8
+3.9
Reviewer: Why do we need all these index URLs here? Can we just have:
Author: I can make those changes, although I do not think that is affecting any of the tests. A maintainer needs to kick off the 3 tests mentioned below.
Author: Still seeing Python 3.8 in the log. Is this file used? It mentions Python 3.8 specifically.
Author: Can a maintainer please kick off the 3 workflows?
Reviewer: Yes, you can always check which workflow file was used here: https://github.com/pytorch/examples/actions/runs/15981567115/workflow
So changing this file is the way to go, I think.