Commit f17131f

sync amax in context parallel and awq act scale
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 7d5f636 commit f17131f

8 files changed: +178 -15 lines changed

examples/nemo_run/qat/README.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
 To perform QAD training, run:

 ```bash
-python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
+python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
 ```

 ## Supported models

modelopt/torch/quantization/model_calib.py

Lines changed: 15 additions & 4 deletions
@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return

-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group

     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
-
+                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -624,13 +625,23 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)

+    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+        # Sync across Data Parallel (DP)
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group)
+        # Sync across Context Parallel (CP)
+        if context_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group)
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group)
+
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True
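As a reading aid, here is a minimal standalone sketch of the two collectives this diff relies on, using plain `torch.distributed` handles rather than modelopt's `DistributedProcessGroup` wrapper (the helper names below are illustrative, not part of the codebase): amax is reduced with MAX so every rank keeps the widest calibration range, while the AWQ-lite `act_scale` is averaged because it accumulates per-rank statistics.

```python
import torch
import torch.distributed as dist


def sync_amax(amax: torch.Tensor, group: dist.ProcessGroup | None) -> None:
    # Element-wise max across the group: every rank ends up with the largest
    # calibration range seen on any data/context-parallel shard.
    if group is not None and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)


def sync_act_scale(act_scale: torch.Tensor, group: dist.ProcessGroup | None) -> None:
    # AWQ-lite act_scale is a running average of activation statistics, so an
    # average across ranks keeps it consistent with single-rank calibration.
    if group is not None and dist.is_initialized():
        dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=group)
```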

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 8 additions & 1 deletion
@@ -23,6 +23,7 @@
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none

@@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]

     def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except:
+            data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
+            data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_context_parallel_group(),
         )
         super()._setup()
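The `try`/`except` in `_setup` prefers the data-parallel group that also spans context-parallel ranks and falls back to the plain DP group when that variant is unavailable. A minimal sketch of the same selection, assuming megatron-core raises when the CP-aware group has not been initialized (the helper name is illustrative; the narrower `except Exception` avoids masking unrelated failures):

```python
from megatron.core.parallel_state import get_data_parallel_group


def pick_data_parallel_group():
    # Prefer the DP group that includes context-parallel ranks; fall back to
    # the plain DP group if the CP-aware variant is not initialized.
    try:
        return get_data_parallel_group(with_context_parallel=True)
    except Exception:
        return get_data_parallel_group()
```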

modelopt/torch/utils/distributed.py

Lines changed: 3 additions & 1 deletion
@@ -241,13 +241,15 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
+        self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

     def __repr__(self) -> str:
-        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
+        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}"


 def get_group(ranks: list[int]):
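A quick usage sketch of the extended `ParallelState` (assuming the class is importable from `modelopt.torch.utils.distributed`; with no arguments the defaults above apply, while the Megatron plugin passes the DP-with-CP, TP, and CP process groups explicitly):

```python
from modelopt.torch.utils.distributed import ParallelState

# Defaults from __init__: data_parallel_group=None, tensor_parallel_group=-1,
# context_parallel_group=-1. The Megatron plugin substitutes real groups from
# megatron.core.parallel_state.
state = ParallelState()
print(state)
# data_parallel_group: ..., tensor_parallel_group: ..., context_parallel_group: ...
```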

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 4 additions & 3 deletions
@@ -83,9 +83,10 @@

 class MegatronModel(MegatronModule):
-    def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
+    def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
         config = TransformerConfig(
             tensor_model_parallel_size=tp_size,
+            context_parallel_size=cp_size,
             pipeline_model_parallel_size=1,
             normalization="LayerNorm",
             # Unused parameters below are set to avoid ZeroDivisionError in __post_init__

@@ -383,13 +384,13 @@ def run_mcore_inference_with_dummy_input(

 def initialize_for_megatron(
-    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
+    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234
 ):
     """Initialize Megatron model parallelism.

     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
-    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
+    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size)
     model_parallel_cuda_manual_seed(seed)
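A condensed sketch of how these new parameters are exercised from a spawned test rank (it mirrors the context-parallel test added further down; `SEED` is replaced by the default value and the body is elided):

```python
from _test_utils.torch_dist.plugins.megatron_common import (
    MegatronModel,
    initialize_for_megatron,
)


def _per_rank_job(rank: int, size: int) -> None:
    # One CP group of `size` ranks; TP and PP stay at 1.
    initialize_for_megatron(context_parallel_size=size, seed=1234)
    model = MegatronModel(cp_size=size).cuda()
    ...  # quantize and check amax sync, as in the helpers above
```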

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 55 additions & 2 deletions
@@ -116,8 +116,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


-def tensor_parallel_test_helper(model, config, tp_group, dp_group):
-    # The input to fist layer, the column parallel should be the same across all tp ranks
+def tensor_parallel_test_helper(model, config, tp_group):
+    # The input to first layer, the column parallel should be the same across all tp ranks
     calib_data = model.get_dummy_input().cuda()
     dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

@@ -149,6 +149,59 @@ def forward_loop(model):

     dist.destroy_process_group()

+
+def data_parallel_test_helper(model, config, dp_group):
+    calib_data = model.get_dummy_input().cuda()
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+
+
+def context_parallel_test_helper(model, config, cp_group):
+    calib_data = model.get_dummy_input().cuda()
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+
+
+def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
+    calib_data = model.get_dummy_input().cuda()
+    # data should be same across each TP rank
+    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    fc1_amax = model.fc1.input_quantizer.amax.clone()
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
+
+    fc2_amax = model.fc2.input_quantizer.amax.clone()
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
+    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
+    assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

 def auto_quantize_helper(model):
     model, search_state = mtq.auto_quantize(
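The helpers above all verify synchronization the same way: if amax is already identical across the group, an extra MAX all-reduce is a no-op, so comparing the tensor before and after the reduction proves the calibration sync happened. A generic form of that check (hypothetical helper, not part of the test utilities):

```python
import torch
import torch.distributed as dist


def assert_synced(tensor: torch.Tensor, group) -> None:
    # If `tensor` is already identical on every rank of `group`, a MAX
    # all-reduce leaves it unchanged; any difference indicates a missed sync.
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.MAX, group=group)
    assert torch.allclose(reduced, tensor)
```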

tests/gpu/torch/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,12 @@ def need_2_gpus():
     if torch.cuda.device_count() < 2:
         pytest.skip("Need at least 2 GPUs to run this test")

+@pytest.fixture
+def need_8_gpus():
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Need at least 8 GPUs to run this test")
+
+

 @pytest.fixture(scope="module")
 def set_torch_dtype(request):

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 86 additions & 3 deletions
@@ -32,6 +32,9 @@
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
     tensor_parallel_test_helper,
+    data_parallel_test_helper,
+    context_parallel_test_helper,
+    data_tensor_context_parallel_test_helper,
 )
 from packaging.version import Version

@@ -41,6 +44,7 @@
 from megatron.core.parallel_state import (
     destroy_model_parallel,
     get_data_parallel_group,
+    get_context_parallel_group,
     get_tensor_model_parallel_group,
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

@@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     # Clean up since this is not a spawned process
     destroy_model_parallel()

-
+# 1. Tensor Parallel Test
 def _test_tensor_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
-    model = MegatronModel(size).cuda()
+    model = MegatronModel(tp_size=size).cuda()

     tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
+        model, config, get_tensor_model_parallel_group()
     )


@@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config):
         size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
     )

+
+# 2. Data Parallel Test
+def _test_data_parallel_helper(config, rank, size):
+    # TODO does this model automatically get copied to both DP ranks?
+    initialize_for_megatron(seed=SEED)
+    model = MegatronModel().cuda()
+
+    data_parallel_test_helper(
+        model, config, get_data_parallel_group()
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2, job=partial(_test_data_parallel_helper, config), backend="nccl"
+    )
+
+
+# 3. Context Parallel Test
+def _test_context_parallel_helper(config, rank, size):
+    initialize_for_megatron(context_parallel_size=size, seed=SEED)
+    model = MegatronModel(cp_size=size).cuda()
+
+    context_parallel_test_helper(
+        model, config, get_context_parallel_group()
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_context_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
+    )
+
+
+# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
+def _test_data_tensor_context_parallel_helper(config, rank, size):
+    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
+    model = MegatronModel(tp_size=2, cp_size=2).cuda()
+
+    data_tensor_context_parallel_test_helper(
+        model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group()
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_tensor_context_parallel(need_8_gpus, config):
+    spawn_multiprocess_job(
+        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+    )

 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
