Commit faff217

add draft of tests
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 75df0f2 commit faff217

File tree: 4 files changed (+151, -8 lines)

- tests/_test_utils/torch_dist/plugins/megatron_common.py
- tests/_test_utils/torch_quantization/quantize_common.py
- tests/gpu/torch/conftest.py
- tests/gpu/torch/quantization/plugins/test_megatron.py

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 4 additions & 3 deletions
@@ -83,9 +83,10 @@


 class MegatronModel(MegatronModule):
-    def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
+    def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
         config = TransformerConfig(
             tensor_model_parallel_size=tp_size,
+            context_parallel_size=cp_size,
             pipeline_model_parallel_size=1,
             normalization="LayerNorm",
             # Unused parameters below are set to avoid ZeroDivisionError in __post_init__

@@ -383,13 +384,13 @@ def run_mcore_inference_with_dummy_input(


 def initialize_for_megatron(
-    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
+    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234
 ):
     """Initialize Megatron model parallelism.

     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
-    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
+    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size)
     model_parallel_cuda_manual_seed(seed)
tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 55 additions & 2 deletions
@@ -116,8 +116,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


-def tensor_parallel_test_helper(model, config, tp_group, dp_group):
-    # The input to fist layer, the column parallel should be the same across all tp ranks
+def tensor_parallel_test_helper(model, config, tp_group):
+    # The input to first layer, the column parallel should be the same across all tp ranks
     calib_data = model.get_dummy_input().cuda()
     dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

@@ -149,6 +149,59 @@ def forward_loop(model):

     dist.destroy_process_group()

+def data_parallel_test_helper(model, config, dp_group):
+    calib_data = model.get_dummy_input().cuda()
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+
+def context_parallel_test_helper(model, config, cp_group):
+    calib_data = model.get_dummy_input().cuda()
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+
+def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
+    calib_data = model.get_dummy_input().cuda()
+    # data should be same across each TP rank
+    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

 def auto_quantize_helper(model):
     model, search_state = mtq.auto_quantize(
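All of the new helpers repeat the same per-quantizer check: after calibration, the amax must already equal the group-wide maximum, so a MAX all-reduce over the relevant process group leaves it unchanged. A hypothetical refactor of that shared assertion (illustrative only, not part of this commit) could look like:

# Hypothetical helper showing the shared assertion pattern; names are illustrative.
import torch
import torch.distributed as dist


def _assert_amax_synced(quantizer, group):
    """amax should already be the group-wide max, so a MAX all-reduce is a no-op."""
    amax = quantizer.amax.clone()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    assert torch.allclose(amax, quantizer.amax)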

tests/gpu/torch/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,12 @@ def need_2_gpus():
     if torch.cuda.device_count() < 2:
         pytest.skip("Need at least 2 GPUs to run this test")

+@pytest.fixture
+def need_8_gpus():
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Need at least 8 GPUs to run this test")
+
+

 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
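Like the existing need_2_gpus fixture, it is consumed by declaring it as a test argument, as the combined test in test_megatron.py below does:

# Requesting the fixture skips the test on machines with fewer than 8 GPUs.
def test_data_tensor_context_parallel(need_8_gpus, config):
    ...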

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 86 additions & 3 deletions
@@ -32,6 +32,9 @@
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
     tensor_parallel_test_helper,
+    data_parallel_test_helper,
+    context_parallel_test_helper,
+    data_tensor_context_parallel_test_helper,
 )
 from packaging.version import Version

@@ -41,6 +44,7 @@
 from megatron.core.parallel_state import (
     destroy_model_parallel,
     get_data_parallel_group,
+    get_context_parallel_group,
     get_tensor_model_parallel_group,
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

@@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     # Clean up since this is not a spawned process
     destroy_model_parallel()

-
+# 1. Tensor Parallel Test
 def _test_tensor_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
-    model = MegatronModel(size).cuda()
+    model = MegatronModel(tp_size=size).cuda()

     tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
+        model, config, get_tensor_model_parallel_group()
     )


@@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config):
         size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
     )

+# 2. Data Parallel Test
+def _test_data_parallel_helper(config, rank, size):
+    # TODO does this model automatically get copied to both DP ranks?
+    initialize_for_megatron(seed=SEED)
+    model = MegatronModel().cuda()
+
+    data_parallel_test_helper(
+        model, config, get_data_parallel_group()
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2, job=partial(_test_data_parallel_helper, config), backend="nccl"
+    )
+
+# 3. Context Parallel Test
+def _test_context_parallel_helper(config, rank, size):
+    initialize_for_megatron(context_parallel_size=size, seed=SEED)
+    model = MegatronModel(cp_size=size).cuda()
+
+    context_parallel_test_helper(
+        model, config, get_context_parallel_group()
+    )
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_context_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
+    )
+
+# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
+def _test_data_tensor_context_parallel_helper(config, rank, size):
+    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
+    model = MegatronModel(tp_size=2, cp_size=2).cuda()
+
+    data_tensor_context_parallel_test_helper(
+        model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group()
+    )
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_tensor_context_parallel(need_8_gpus, config):
+    spawn_multiprocess_job(
+        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+    )

 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
