@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -13,6 +14,7 @@
 
 from log_utils import rank_log, get_logger, verify_min_gpu_count
 
+from torch.distributed.tensor.debug import CommDebugMode
 
 # ---- GPU check ------------
 _min_gpu_count = 2
@@ -63,9 +65,10 @@ def forward(self, x):
 """
 logger = get_logger()
 
+device_type = torch.accelerator.current_accelerator().type
 # create a device mesh based on the given world_size.
 device_mesh = init_device_mesh(
-    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+    device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
 )
 
 _rank = device_mesh.get_rank()
@@ -75,7 +78,7 @@ def forward(self, x):
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 
 # create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
-model = ToyModel().to("cuda")
+model = ToyModel().to(device_type)
 
 # Custom parallelization plan for the model
 sp_model = parallelize_module(
@@ -87,6 +90,8 @@ def forward(self, x):
     },
 )
 
+if torch.distributed.get_rank() == 0:
+    print(f"model {sp_model}")
 
 # Create a optimizer for the parallelized module.
 lr = 0.25
@@ -98,12 +103,19 @@ def forward(self, x):
 num_iters = 10
 rank_log(_rank, logger, "Sequence Parallel training starting...")
 
+
 for i in range(num_iters):
     # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10, device="cuda")
-    output = sp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    #inp = torch.rand(20, 10, device=device_type)
+    inp = torch.rand(1, 10, device=device_type)
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = sp_model(inp)
+        output.sum().backward()
+        optimizer.step()
     rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
 
+    if i == 0:
+        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)}")
+
 rank_log(_rank, logger, "Sequence Parallel training completed!")
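
For reference, a minimal standalone sketch of the CommDebugMode pattern this commit adds is shown below. It is not part of the commit: the file name, ToyModel definition, and layer sizes are illustrative assumptions; only the CommDebugMode, init_device_mesh, and parallelize_module usage mirrors the change above. It assumes a recent PyTorch build where torch.distributed.tensor.debug and torch.accelerator are available, plus at least 2 GPUs.

# comm_debug_sketch.py (hypothetical file name)
# Run with: torchrun --nnodes 1 --nproc-per-node 2 comm_debug_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    """Two linear layers with a ReLU in between, as in the sequence parallel example."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


device_type = torch.accelerator.current_accelerator().type
mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))

# Shard the first linear column-wise and the second row-wise, with the model
# input/output sharded on dim 0 (the sequence dimension) for sequence parallel.
sp_model = parallelize_module(
    ToyModel().to(device_type),
    mesh,
    {
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

inp = torch.rand(1, 10, device=device_type)
comm_mode = CommDebugMode()
with comm_mode:  # records every collective DTensor issues inside this context
    sp_model(inp).sum().backward()

if torch.distributed.get_rank() == 0:
    # Dict of collective ops (e.g. all_gather / reduce_scatter) to call counts
    # for one forward/backward pass, plus a per-module trace table.
    print(comm_mode.get_comm_counts())
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))

torch.distributed.destroy_process_group()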