
Commit 997c123

githubsgi authored and soumith committed
Updating Python to 3.10, fsdp_tp_example.py to accelerator
1 parent 3fc7853 commit 997c123

2 files changed: +7 −6 lines

.github/workflows/main_distributed.yaml

Lines changed: 2 additions & 2 deletions

@@ -17,10 +17,10 @@ jobs:
 
     steps:
     - uses: actions/checkout@v2
-    - name: Set up Python 3.9
+    - name: Set up Python 3.10
      uses: actions/setup-python@v2
      with:
-        python-version: 3.9
+        python-version: 3.10
    - name: Install PyTorch
      uses: astral-sh/setup-uv@v6
    - name: Run Tests

distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 5 additions & 4 deletions

@@ -77,10 +77,11 @@
 # create a sharding plan based on the given world_size.
 dp_size = _world_size // tp_size
 
+device_type = torch.accelerator.current_accelerator().type
 # Create a device mesh with 2 dimensions.
 # First dim is the data parallel dimension
 # Second dim is the tensor parallel dimension.
-device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
+device_mesh = init_device_mesh(device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
 
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 tp_mesh = device_mesh["tp"]
@@ -92,10 +93,10 @@
 # to mimic the behavior of the dataloader.
 dp_rank = dp_mesh.get_local_rank()
 
-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
+# create model and move it to GPU - initdevice_type_mesh has already mapped GPU ids.
 simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)
 
-model = Transformer.from_model_args(simple_llama2_config).to("cuda")
+model = Transformer.from_model_args(simple_llama2_config).to(device_type)
 
 # init model weights
 model.init_weights()
@@ -170,7 +171,7 @@
 for i in range(num_iterations):
     # seeding with dp_rank to ensure identical inputs for TP groups
     torch.manual_seed(i + dp_rank)
-    inp = torch.randint(32000, (8, 256), device="cuda")
+    inp = torch.randint(32000, (8, 256), device=device_type)
 
     output = sharded_model(inp)
     output.sum().backward()
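
For reference, a minimal sketch of the accelerator-agnostic pattern the new lines rely on. It assumes PyTorch >= 2.6 (where torch.accelerator is available) and a torchrun launch that sets WORLD_SIZE; the CPU fallback and the tp_size value are illustrative and not part of this commit.

import os

import torch
from torch.distributed.device_mesh import init_device_mesh

# Pick whichever accelerator backend is present (cuda, xpu, ...); fall back to CPU.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"

# Build the same 2-D (data parallel x tensor parallel) mesh from the world size.
world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
tp_size = 2  # illustrative value, not taken from the diff
dp_size = world_size // tp_size
device_mesh = init_device_mesh(device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

# Tensors and modules then target device_type instead of a hard-coded "cuda".
inp = torch.randint(32000, (8, 256), device=device_type)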
