Merged
Commits
28 commits
f17131f
sync amax in context parallel and awq act scale
jenchen13 Sep 24, 2025
42519cc
lint
jenchen13 Sep 25, 2025
264adbb
test weight quantizer too
jenchen13 Sep 25, 2025
7cbe5b9
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 25, 2025
1f7d17e
fix test
jenchen13 Sep 26, 2025
71a9f7a
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 29, 2025
d02365c
awq test
jenchen13 Sep 29, 2025
5a572da
move awq test inside megatron tests
jenchen13 Sep 29, 2025
fc0bb88
fix amax tests
jenchen13 Sep 30, 2025
95da832
fix awq lite param
jenchen13 Sep 30, 2025
34c11ef
fix test
jenchen13 Sep 30, 2025
10e3e2b
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Sep 30, 2025
9f0691f
uncomment test
jenchen13 Sep 30, 2025
fa8f4c8
add print
jenchen13 Oct 1, 2025
d1fac44
docstring
jenchen13 Oct 1, 2025
22b8b73
fix tests
jenchen13 Oct 2, 2025
ca7c0e8
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 2, 2025
3f857a3
fix multiprocess size
jenchen13 Oct 2, 2025
93bfd52
fix tests
jenchen13 Oct 8, 2025
6761109
consolidate tests
jenchen13 Oct 9, 2025
291cfa3
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 9, 2025
a106dd9
fix test
jenchen13 Oct 9, 2025
50000dd
fix bug
jenchen13 Oct 10, 2025
2664563
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 10, 2025
440ca48
update qat readme
jenchen13 Oct 10, 2025
2e8ef58
update readme
jenchen13 Oct 10, 2025
5cb380c
Merge branch 'main' into jennifchen/cp_amax_sync
jenchen13 Oct 10, 2025
afe6f34
fix dist has_nan
jenchen13 Oct 10, 2025
2 changes: 1 addition & 1 deletion examples/nemo_run/qat/README.md
@@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
32 changes: 26 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and context parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_cp(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -598,19 +599,38 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
# Sync across Data Parallel (DP)
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)
# Sync across Context Parallel (CP)
if context_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp_cp(
module,
module.parallel_state.data_parallel_group,
module.parallel_state.context_parallel_group,
)

AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
9 changes: 8 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
Comment on lines 224 to 233
Contributor:
⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
try:
context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
context_parallel_group = -1
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
context_parallel_group,
)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

super()._setup()

8 changes: 7 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,13 +241,19 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"context_parallel_group: {self.context_parallel_group}"
)


def get_group(ranks: list[int]):
14 changes: 11 additions & 3 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -83,9 +83,10 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -383,13 +384,20 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
seed=1234,
context_parallel_size=1,
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)


75 changes: 73 additions & 2 deletions tests/_test_utils/torch_quantization/quantize_common.py
Contributor:

Looks like we don't need separate methods `tensor_parallel_test_helper`, `dp_cp_parallel_test_helper`, and `data_tensor_context_parallel_test_helper` for testing all the combinations. Can we merge them into one and call it as `data_tensor_context_parallel_test_helper(..., tp_group=None, dp_group=None, cp_group=None)`?
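If the three helpers were merged along those lines, a minimal sketch of one possible shape is below. It assumes optional `tp_group`/`dp_group`/`cp_group` arguments and keeps only the shared amax checks already used in this file; the name `combined_parallel_test_helper` is illustrative, not part of this PR, and any TP-specific assertions from `tensor_parallel_test_helper` would still need to be folded in.

```python
import torch
import torch.distributed as dist

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer


def combined_parallel_test_helper(model, config, tp_group=None, dp_group=None, cp_group=None):
    """Hypothetical merged helper: only the groups that are passed in are checked."""
    calib_data = model.get_dummy_input().cuda()
    if tp_group is not None:
        # The input to the first (column-parallel) layer must match across TP ranks.
        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

    def forward_loop(model):
        model(calib_data)

    model = mtq.quantize(model, config, forward_loop)

    def check_amax_synced(quantizer):
        # After calibration, reducing amax over any of the given groups should be a no-op.
        amax = quantizer.amax.clone()
        for group in (dp_group, cp_group, tp_group):
            if group is not None:
                dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        assert torch.allclose(amax, quantizer.amax)

    def check_quantizer(quantizer):
        if isinstance(quantizer, SequentialQuantizer):
            for q in quantizer:
                check_amax_synced(q)
        else:
            check_amax_synced(quantizer)

    # Weight-only configs have no input-quantizer amax to check.
    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
        check_quantizer(model.fc1.input_quantizer)
        check_quantizer(model.fc2.input_quantizer)
    check_quantizer(model.fc1.weight_quantizer)
    check_quantizer(model.fc2.weight_quantizer)
```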
@@ -23,6 +23,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to

@@ -116,8 +117,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
def tensor_parallel_test_helper(model, config, tp_group):
# The input to first layer, the column parallel should be the same across all tp ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

@@ -150,6 +151,76 @@ def forward_loop(model):
dist.destroy_process_group()


def dp_cp_parallel_test_helper(model, config, group):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def reduce_amax(quantizer):
amax = quantizer.amax.clone()
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
assert torch.allclose(amax, quantizer.amax)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
reduce_amax(model.fc1.input_quantizer)
reduce_amax(model.fc2.input_quantizer)

# Weight quantizer amax
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc1.weight_quantizer)
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc2.weight_quantizer)


def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
calib_data = model.get_dummy_input().cuda()
# data should be same across each TP rank
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def reduce_amax(quantizer):
amax = quantizer.amax.clone()
print("amax before reduce", amax)
print("quantizer.amax before reduce", quantizer.amax)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
print("amax after reduce", amax)
print("quantizer.amax after reduce", quantizer.amax)
assert torch.allclose(amax, quantizer.amax)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
reduce_amax(model.fc1.input_quantizer)
reduce_amax(model.fc2.input_quantizer)

if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc1.weight_quantizer)

if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc2.weight_quantizer)


def auto_quantize_helper(model):
model, search_state = mtq.auto_quantize(
model,
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()
94 changes: 90 additions & 4 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
Contributor:

Same as https://github.com/NVIDIA/TensorRT-Model-Optimizer/pull/359/files#r2410523895
Can we combine the DP, TP, and CP tests by parameterizing them?
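A hedged sketch of what that parameterization could look like, assuming the merged `combined_parallel_test_helper` sketched above and the utilities already imported in this test module (`initialize_for_megatron`, `MegatronModel`, `SEED`, `spawn_multiprocess_job`, and the parallel-group getters); `_test_parallelism_helper`, the config subset, and the exact size table are illustrative, not part of this PR.

```python
from functools import partial

import pytest
import torch

import modelopt.torch.quantization as mtq


def _test_parallelism_helper(config, tp_size, cp_size, rank, size):
    # Vary the seed per rank so calibration data differs across DP/CP ranks;
    # the helper averages the input over the TP group so TP ranks still agree.
    initialize_for_megatron(
        tensor_model_parallel_size=tp_size,
        context_parallel_size=cp_size,
        seed=SEED + rank,
    )
    model = MegatronModel(tp_size=tp_size, cp_size=cp_size).cuda()
    combined_parallel_test_helper(
        model,
        config,
        tp_group=get_tensor_model_parallel_group() if tp_size > 1 else None,
        dp_group=get_data_parallel_group(),
        cp_group=get_context_parallel_group() if cp_size > 1 else None,
    )


@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG])
@pytest.mark.parametrize(
    ("tp_size", "cp_size", "world_size"),
    [
        (1, 1, 2),  # data parallel only
        (2, 1, 2),  # tensor parallel only
        (1, 2, 2),  # context parallel only
        (2, 2, 8),  # DP=2 x TP=2 x CP=2
    ],
)
def test_parallelism_combinations(config, tp_size, cp_size, world_size):
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs")
    spawn_multiprocess_job(
        size=world_size,
        job=partial(_test_parallelism_helper, config, tp_size, cp_size),
        backend="nccl",
    )
```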
@@ -31,6 +31,8 @@
from _test_utils.torch_quantization.quant_utils import get_model_size
from _test_utils.torch_quantization.quantize_common import (
auto_quantize_helper,
data_tensor_context_parallel_test_helper,
dp_cp_parallel_test_helper,
tensor_parallel_test_helper,
)
from packaging.version import Version
@@ -40,6 +42,7 @@
import megatron.core
from megatron.core.parallel_state import (
destroy_model_parallel,
get_context_parallel_group,
get_data_parallel_group,
get_tensor_model_parallel_group,
)
@@ -92,13 +95,12 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
destroy_model_parallel()


# 1. Tensor Parallel Test
def _test_tensor_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
model = MegatronModel(size).cuda()
model = MegatronModel(tp_size=size).cuda()

tensor_parallel_test_helper(
model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
)
tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())


@pytest.mark.parametrize(
@@ -119,6 +121,90 @@ def test_tensor_parallel(need_2_gpus, config):
)


# 2. Data Parallel Test
def _test_data_parallel_helper(config, rank, size):
initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks
model = MegatronModel().cuda()

dp_cp_parallel_test_helper(model, config, get_data_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_parallel(need_2_gpus, config):
Collaborator:
For data parallel and context parallel, do we really need to test all configs? Or is testing one sufficient, given that we have extensive tensor parallel tests?
Thoughts? @realAsma @jenchen13

spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")


# 3. Context Parallel Test
def _test_context_parallel_helper(config, rank, size):
initialize_for_megatron(
context_parallel_size=size, seed=SEED + rank
) # modify seed so data is different across ranks
model = MegatronModel(cp_size=size).cuda()

dp_cp_parallel_test_helper(model, config, get_context_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_context_parallel(need_2_gpus, config):
spawn_multiprocess_job(
size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
)


# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
def _test_data_tensor_context_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
model = MegatronModel(tp_size=2, cp_size=2).cuda()

data_tensor_context_parallel_test_helper(
model,
config,
get_data_parallel_group(),
get_tensor_model_parallel_group(),
get_context_parallel_group(),
)


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_tensor_context_parallel(need_8_gpus, config):
spawn_multiprocess_job(
size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
)


def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
"""Build the model."""
