
Commit 99c44d3

Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] (#359)
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 40a7d24 commit 99c44d3

File tree

10 files changed (+291, -57 lines)

examples/nemo_run/qat/README.md

Lines changed: 9 additions & 6 deletions
@@ -56,18 +56,21 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar

 You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

-To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
+To run the example locally, first clone the `TensorRT-Model-Optimizer` repository, then mount the repository to a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09. After launching the Docker container, make sure to also set your HuggingFace token for dataset/model downloading.
+
+Set up repo:

 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
-- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`

-Example docker command:
+Run docker command (modify with your paths) and export the HuggingFace token:

 ```bash
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
+
+export HF_TOKEN=<your-token>
 ```

-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.

 ### Running the Flow Locally

@@ -92,7 +95,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
 To perform QAD training, run:

 ```bash
-python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
+python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
 ```

 ## Supported models

examples/nemo_run/qat/nemo_qat_flow.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def main(args):
         global_batch_size=GBS,
         micro_batch_size=MBS,
         use_hf_tokenizer_chat_template=True,
-        num_workers=2,
+        num_workers=1,
         persistent_workers=True,
     )
     if args.distill:

modelopt/torch/quantization/model_calib.py

Lines changed: 25 additions & 6 deletions
@@ -26,7 +26,7 @@

 from modelopt.torch.opt.searcher import ForwardLoop
 from modelopt.torch.utils import print_rank_0
-from modelopt.torch.utils.distributed import ParallelState
+from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         return

     def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel group."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
                 sync_quantizer_amax_across_dp(_q, parallel_state)
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state):
         for child in module.children():
             if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                 sync_quantizer_amax_across_dp(child, module.parallel_state)
-
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -114,6 +114,7 @@ def sync_quantizer_amax_across_tp(
         axes_for_sync: list,
         parallel_state: ParallelState,
     ):
+        # Syncing amax across TP for sequential quantizer
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
                 sync_quantizer_amax_across_tp(
@@ -598,19 +599,37 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)

+    def sync_act_scale_across_dp(module, data_parallel_group):
+        """Sync activation scale across Data Parallel (DP)."""
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(
+                module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
+            )
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
+            # Hack: MoEs forward all tokens through all experts if _if_calib is True
+            module._if_calib = True
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+
+            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
-            ):
+            )
+            has_nan = DistributedProcessGroup.get_dist_syncd_obj(
+                has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
+            )
+
+            if has_nan:
                 module.awq_lite.is_enabled = False
-            # Hack: MoEs forward all tokens through all experts if _if_calib is True
-            module._if_calib = True
+            else:
+                sync_act_scale_across_dp(
+                    module,
+                    module.parallel_state.data_parallel_group,
+                )

     AWQLiteHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 11 additions & 1 deletion
@@ -15,13 +15,15 @@

 """Support quantization for megatron linear layers."""

+import logging
 import warnings
 from typing import Any

 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -38,6 +40,8 @@
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

+logger = logging.getLogger(__name__)
+
 __all__ = []


@@ -217,8 +221,14 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]

     def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            logger.warning("Context parallel group is not initialized, using data parallel group")
+            data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
+            data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
         )
         super()._setup()
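The `_setup` change follows a small fallback pattern: prefer the data-parallel group that also spans context-parallel ranks, and fall back to the plain DP group when that group was never created (the `AssertionError` branch above). A standalone sketch of the same logic follows; `resolve_amax_sync_group` is a hypothetical name used only for illustration.

```python
import logging

from megatron.core.parallel_state import get_data_parallel_group

logger = logging.getLogger(__name__)


def resolve_amax_sync_group():
    """Pick the group used for amax/act_scale sync: DP x CP when available, else plain DP."""
    try:
        return get_data_parallel_group(with_context_parallel=True)
    except AssertionError:
        # Raised by Megatron when model parallelism was initialized without context parallelism.
        logger.warning("Context parallel group is not initialized, using data parallel group")
        return get_data_parallel_group()
```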

modelopt/torch/utils/distributed.py

Lines changed: 4 additions & 1 deletion
@@ -247,7 +247,10 @@ def __init__(
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)

     def __repr__(self) -> str:
-        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
+        return (
+            f"data_parallel_group: {self.data_parallel_group}, "
+            f"tensor_parallel_group: {self.tensor_parallel_group}, "
+        )


 def get_group(ranks: list[int]):

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 20 additions & 4 deletions
@@ -84,9 +84,12 @@


 class MegatronModel(MegatronModule):
-    def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
+    def __init__(
+        self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False, tp_group=None
+    ):
         config = TransformerConfig(
             tensor_model_parallel_size=tp_size,
+            context_parallel_size=cp_size,
             pipeline_model_parallel_size=1,
             normalization="LayerNorm",
             # Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -104,6 +107,7 @@ def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
             gather_output=False,
             skip_bias_add=True,
             is_expert=False,
+            tp_group=tp_group,
         )
         self.activation = nn.ReLU()
         if use_te_norm:
@@ -118,6 +122,7 @@
             skip_bias_add=True,
             input_is_parallel=True,
             is_expert=False,
+            tp_group=tp_group,
         )

     def forward(self, x):
@@ -127,7 +132,11 @@ def forward(self, x):
             x = x[0]
         return x

-    def get_dummy_input(self) -> torch.Tensor:
+    def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
+        if seed is not None:
+            gen = torch.Generator()
+            gen.manual_seed(seed)
+            return torch.randn(1, 4, 32, generator=gen)
         return torch.randn(1, 4, 32)


@@ -390,13 +399,20 @@ def run_mcore_inference_with_dummy_input(


 def initialize_for_megatron(
-    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
+    tensor_model_parallel_size=1,
+    pipeline_model_parallel_size=1,
+    seed=1234,
+    context_parallel_size=1,
 ):
     """Initialize Megatron model parallelism.

     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
-    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
+    initialize_model_parallel(
+        tensor_model_parallel_size,
+        pipeline_model_parallel_size,
+        context_parallel_size=context_parallel_size,
+    )
     model_parallel_cuda_manual_seed(seed)
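The seeded `get_dummy_input` is what lets the new test feed each data-parallel rank different but reproducible calibration data. A small sketch of the same idea, reusing the shape and the seed-per-DP-rank scheme from the diff (`per_rank_dummy_input` is an illustrative helper, and an initialized `torch.distributed` process group is assumed):

```python
import torch
import torch.distributed as dist


def per_rank_dummy_input(dp_group=None) -> torch.Tensor:
    """Deterministic dummy batch that differs across DP ranks but is stable across runs."""
    gen = torch.Generator()
    gen.manual_seed(dist.get_rank(group=dp_group))  # seed = data-parallel rank, as in the test helper
    return torch.randn(1, 4, 32, generator=gen)
```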

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 85 additions & 25 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+from unittest.mock import patch

 import pytest
 import torch
@@ -22,7 +23,9 @@

 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+import modelopt.torch.quantization.model_calib as model_calib_module  # needed for patching awq_lite
 from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to

@@ -116,38 +119,95 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


-def tensor_parallel_test_helper(model, config, tp_group, dp_group):
-    # The input to fist layer, the column parallel should be the same across all tp ranks
-    calib_data = model.get_dummy_input().cuda()
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    for group in groups:
+        if group is not None:
+            dist.all_reduce(quantizer_attr, op=op, group=group)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

-    def forward_loop(model):
-        model(calib_data)

-    model = mtq.quantize(model, config, forward_loop)
+original_awq_lite = model_calib_module.awq_lite

-    # Sanity check
-    forward_loop(model)

-    if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
-        # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        activation_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Function to mock awq_lite function to always use debug=True for testing"""
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)

-    # Lets check the row parallel weight amax; it should be the same across all tp ranks
-    weight_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)

-    if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
-        input_quantizer = model.fc1.input_quantizer
-        pre_quant_scale = input_quantizer.pre_quant_scale.clone()
-        dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
+@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True
+):
+    # Calib data should be different across each DP rank
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
+
+    if tp_group is not None:
+        # The input to first layer, the column parallel should be the same across all tp ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

-    dist.destroy_process_group()
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    # Input quantizer amax
+    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
+    # Channel-wise (INT8) only expects same amax across row parallel ranks
+    # Block-wise quantization does not expect same amax across row and column parallel ranks
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
+    # It is different across DP/CP ranks since the input is different
+    if (
+        test_pre_quant_scale
+        and tp_group
+        and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]
+    ):
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Check act scale
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
+        )


 def auto_quantize_helper(model):
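
For context, here is a hypothetical way a multi-GPU test worker might drive the new helper. The group wiring and sizes are illustrative assumptions rather than the repository's actual test harness, and the imports of the test utilities are assumed to be available in the worker; note that `mock_awq_lite` is injected by the helper's own `@patch` decorator, so the caller does not pass it.

```python
# Hypothetical sketch: runs inside each spawned test worker after torch.distributed is initialized.
# Assumes `initialize_for_megatron`, `MegatronModel` (megatron_common.py) and
# `data_tensor_context_parallel_test_helper` (quantize_common.py) are importable in the worker.
import megatron.core.parallel_state as ps

import modelopt.torch.quantization as mtq


def run_dp_tp_sync_check():
    # 2-way tensor parallel; the remaining ranks form the data-parallel (and context-parallel) group.
    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=1)
    model = MegatronModel(tp_size=2, tp_group=ps.get_tensor_model_parallel_group()).cuda()
    # `mock_awq_lite` is supplied by the helper's @patch decorator, so it is not passed here.
    data_tensor_context_parallel_test_helper(
        model,
        mtq.FP8_DEFAULT_CFG,
        dp_group=ps.get_data_parallel_group(with_context_parallel=True),
        tp_group=ps.get_tensor_model_parallel_group(),
    )
```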

tests/gpu/torch/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@ def need_2_gpus():
         pytest.skip("Need at least 2 GPUs to run this test")


+@pytest.fixture
+def need_8_gpus():
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Need at least 8 GPUs to run this test")
+
+
 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
     orig_dtype = torch.get_default_dtype()
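
A test that needs the full 8-GPU data/tensor/context-parallel layout can simply request the new fixture and let pytest skip it on smaller machines; the test name below is hypothetical.

```python
def test_amax_sync_dp_tp_cp(need_8_gpus):
    # Skipped automatically when fewer than 8 GPUs are visible; otherwise spawn the
    # distributed workers and run the DP/TP/CP synchronization checks here.
    ...
```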
