32
32
from _test_utils .torch_quantization .quantize_common import (
33
33
auto_quantize_helper ,
34
34
tensor_parallel_test_helper ,
35
+ data_parallel_test_helper ,
36
+ context_parallel_test_helper ,
37
+ data_tensor_context_parallel_test_helper ,
35
38
)
36
39
from packaging .version import Version
37
40
41
44
from megatron .core .parallel_state import (
42
45
destroy_model_parallel ,
43
46
get_data_parallel_group ,
47
+ get_context_parallel_group ,
44
48
get_tensor_model_parallel_group ,
45
49
)
46
50
from megatron .core .tensor_parallel .layers import ColumnParallelLinear , RowParallelLinear
@@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
91
95
# Clean up since this is not a spawned process
92
96
destroy_model_parallel ()
93
97
94
-
98
# 1. Tensor Parallel Test
def _test_tensor_parallel_helper(config, rank, size):
    """Spawned per-rank worker: build a TP-sharded model and run the TP quantization checks."""
    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
    tp_model = MegatronModel(tp_size=size).cuda()
    tensor_parallel_test_helper(tp_model, config, get_tensor_model_parallel_group())
102
106
103
107
@@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config):
118
122
size = 2 , job = partial (_test_tensor_parallel_helper , config ), backend = "nccl"
119
123
)
120
124
125
# 2. Data Parallel Test
def _test_data_parallel_helper(config, rank, size):
    """Spawned per-rank worker: quantize a replicated model and run the DP checks."""
    # TODO: confirm the model is automatically replicated onto both DP ranks.
    initialize_for_megatron(seed=SEED)
    dp_model = MegatronModel().cuda()
    data_parallel_test_helper(dp_model, config, get_data_parallel_group())
134
+
135
+
136
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_parallel(need_2_gpus, config):
    """Launch the data-parallel quantization test on 2 GPUs, one process per rank."""
    job = partial(_test_data_parallel_helper, config)
    spawn_multiprocess_job(size=2, job=job, backend="nccl")
152
+
153
# 3. Context Parallel Test
def _test_context_parallel_helper(config, rank, size):
    """Spawned per-rank worker: build a CP-sharded model and run the CP quantization checks."""
    initialize_for_megatron(context_parallel_size=size, seed=SEED)
    cp_model = MegatronModel(cp_size=size).cuda()
    context_parallel_test_helper(cp_model, config, get_context_parallel_group())
161
+
162
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_context_parallel(need_2_gpus, config):
    """Launch the context-parallel quantization test on 2 GPUs, one process per rank."""
    job = partial(_test_context_parallel_helper, config)
    spawn_multiprocess_job(size=2, job=job, backend="nccl")
178
+
179
# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
def _test_data_tensor_context_parallel_helper(config, rank, size):
    """Spawned per-rank worker: combined DP/TP/CP sharding, then run the combined checks."""
    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
    combined_model = MegatronModel(tp_size=2, cp_size=2).cuda()
    data_tensor_context_parallel_test_helper(
        combined_model,
        config,
        get_data_parallel_group(),
        get_tensor_model_parallel_group(),
        get_context_parallel_group(),
    )
187
+
188
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_tensor_context_parallel(need_8_gpus, config):
    """Launch the combined DP=2/TP=2/CP=2 quantization test on 8 GPUs."""
    job = partial(_test_data_tensor_context_parallel_helper, config)
    spawn_multiprocess_job(size=8, job=job, backend="nccl")
121
204
122
205
def _gpt_model_provider (tp_size : int , hidden_size = 256 , vocab_size = 64 , meta_device = False ):
123
206
"""Build the model."""
0 commit comments