Commit 6bfe651

githubsgi authored and soumith committed
Cuda to accelerator, +CommDebugMode
1 parent d0b7e37 commit 6bfe651
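
This commit applies two patterns to both tensor-parallelism examples: it detects the active accelerator at runtime instead of hard-coding "cuda", and it wraps each training step in CommDebugMode to report the collectives that DTensor issues. Below is a minimal sketch of the device-detection half, assuming a PyTorch build that ships torch.accelerator; the None check and "cpu" fallback are illustrative assumptions, not part of the commit.

import torch

# Ask PyTorch which accelerator backend is active ("cuda", "xpu", "mps", ...).
# The diffs below use .type directly; the fallback to "cpu" is an extra
# assumption for hosts without any accelerator.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
print(f"running on: {device_type}")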

File tree

distributed/tensor_parallelism/sequence_parallel_example.py
distributed/tensor_parallelism/tensor_parallel_example.py

2 files changed: +34, -14 lines

distributed/tensor_parallelism/sequence_parallel_example.py

Lines changed: 18 additions & 6 deletions
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -13,6 +14,7 @@
 
 from log_utils import rank_log, get_logger, verify_min_gpu_count
 
+from torch.distributed.tensor.debug import CommDebugMode
 
 # ---- GPU check ------------
 _min_gpu_count = 2
@@ -63,9 +65,10 @@ def forward(self, x):
 """
 logger = get_logger()
 
+device_type = torch.accelerator.current_accelerator().type
 # create a device mesh based on the given world_size.
 device_mesh = init_device_mesh(
-    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+    device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
 )
 
 _rank = device_mesh.get_rank()
@@ -75,7 +78,7 @@ def forward(self, x):
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 
 # create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
-model = ToyModel().to("cuda")
+model = ToyModel().to(device_type)
 
 # Custom parallelization plan for the model
 sp_model = parallelize_module(
@@ -87,6 +90,8 @@ def forward(self, x):
     },
 )
 
+if torch.distributed.get_rank() == 0:
+    print(f"model {sp_model}")
 
 # Create an optimizer for the parallelized module.
 lr = 0.25
@@ -98,12 +103,19 @@ def forward(self, x):
 num_iters = 10
 rank_log(_rank, logger, "Sequence Parallel training starting...")
 
+
 for i in range(num_iters):
     # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10, device="cuda")
-    output = sp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    # inp = torch.rand(20, 10, device=device_type)
+    inp = torch.rand(1, 10, device=device_type)
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = sp_model(inp)
+        output.sum().backward()
+        optimizer.step()
     rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
 
+    if i == 0:
+        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)} ")
+
 rank_log(_rank, logger, "Sequence Parallel training completed!")
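
For reference, the CommDebugMode pattern that both files now share looks like this in isolation. A minimal sketch, assuming the parallelized module (sp_model), the input inp, and the optimizer already exist as in the diff above.

from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
with comm_mode:  # record the collectives triggered by DTensor operations
    output = sp_model(inp)
    output.sum().backward()
    optimizer.step()

# Inspect what was recorded: per-collective counts, parameter sharding info,
# and a per-module tracing table (noise_level controls verbosity).
print(comm_mode.get_comm_counts())
print(comm_mode.get_sharding_info())
print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))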

distributed/tensor_parallelism/tensor_parallel_example.py

Lines changed: 16 additions & 8 deletions
@@ -10,6 +10,7 @@
 )
 
 from log_utils import rank_log, get_logger, verify_min_gpu_count
+from torch.distributed.tensor.debug import CommDebugMode
 
 # ---- GPU check ------------
 _min_gpu_count = 2
@@ -76,8 +77,8 @@ def forward(self, x):
 
 # create a device mesh based on the given world_size.
 _world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+device_type = torch.accelerator.current_accelerator().type
+device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
 _rank = device_mesh.get_rank()
 
 
@@ -88,8 +89,8 @@ def forward(self, x):
 
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 
-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
-tp_model = ToyModel().to("cuda")
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to(device_type)
 
 
 # Custom parallelization plan for the model
@@ -102,6 +103,9 @@ def forward(self, x):
     },
 )
 
+if torch.distributed.get_rank() == 0:
+    print(f"model {tp_model}")
+
 # Create an optimizer for the parallelized module.
 lr = 0.25
 optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
@@ -116,10 +120,14 @@ def forward(self, x):
     # For TP, input needs to be same across all TP ranks.
     # Setting the random seed is to mimic the behavior of dataloader.
     torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    output = tp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    inp = torch.rand(4, 10, device=device_type)
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = tp_model(inp)
+        output.sum().backward()
+        optimizer.step()
     rank_log(_rank, logger, f"Tensor Parallel iter {i} completed")
+    if i == 1:
+        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)} ")
 
 rank_log(_rank, logger, "Tensor Parallel training completed!")
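
Taken together, the device-mesh setup both files converge on is sketched below. It assumes the script is launched with torchrun (which sets WORLD_SIZE) and that init_device_mesh is imported from torch.distributed.device_mesh, as in recent PyTorch releases.

import os
import torch
from torch.distributed.device_mesh import init_device_mesh

# Detect the backend instead of hard-coding "cuda", then build a 1-D mesh
# spanning every rank that torchrun launched.
device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(
    device_type=device_type,
    mesh_shape=(int(os.environ["WORLD_SIZE"]),),
)
rank = device_mesh.get_rank()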
