2 changes: 1 addition & 1 deletion examples/nemo_run/qat/README.md
@@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
32 changes: 26 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and context parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_cp(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -598,19 +599,38 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
# Sync across Data Parallel (DP)
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)
# Sync across Context Parallel (CP)
if context_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp_cp(
module,
module.parallel_state.data_parallel_group,
module.parallel_state.context_parallel_group,
)
Comment on lines 615 to 627
@coderabbitai (coderabbitai bot) commented on Sep 29, 2025

⚠️ Potential issue | 🟠 Major

Past review concern about NaN propagation remains unresolved.

The previous review identified that NaN detection is performed locally (lines 624-627), but the synchronization call (lines 629-633) will still propagate NaNs from any rank to all ranks in the DP/CP group through the AVG operation. This creates an inconsistent state where:

  1. Rank A detects NaN → sets is_enabled = False locally
  2. Rank B has valid data → enters else block and calls sync
  3. During sync, Rank A's NaN propagates to Rank B via AVG operation
  4. Rank B now has NaN-corrupted act_scale but is_enabled = True

Consider one of these solutions:

  • Add a collective NaN check before sync: use all_reduce to check if ANY rank has NaN, then disable all ranks if true
  • Move the sync before the NaN check and perform NaN detection after synchronization
  • Guard the sync call: only call sync_act_scale_across_dp_cp when module.awq_lite.is_enabled is true (but this requires collective agreement on enabled state)

Based on past review comments.
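
As a rough illustration of the second option above (sync first, then run the NaN check on the already-averaged values so every rank takes the same branch), a minimal sketch that reuses the helper and fields from the diff; this is a sketch of the suggested change, not the PR's implementation, and it assumes `module`, `torch`, and `sync_act_scale_across_dp_cp` from the surrounding calibration loop:

```python
# Sketch only (option 2 above): average act_scale across DP/CP first, then
# run the NaN check on the synchronized result. The AVG all-reduce
# propagates any NaN in act_scale to every rank in the group, so all ranks
# observe the same values and take the same enable/disable branch.
sync_act_scale_across_dp_cp(
    module,
    module.parallel_state.data_parallel_group,
    module.parallel_state.context_parallel_group,
)
if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
):
    module.awq_lite.is_enabled = False
```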

@jenchen13 (Contributor, Author) replied:

I now only sync the act_scale if no NaNs have been detected (it is after an else statement)

@coderabbitai (coderabbitai bot) replied:

@jenchen13 Thank you for the update. While moving the sync to the else block prevents syncing on ranks that detect NaN locally, there's a critical distributed synchronization issue that remains:

The Problem: Collective Operation Mismatch

The dist.all_reduce() call inside sync_act_scale_across_dp_cp is a collective operation that requires all ranks in the DP/CP group to participate. With the current implementation:

  • Rank A (has NaN): Detects NaN → sets is_enabled = False → skips the else block → does NOT call all_reduce
  • Rank B (no NaN): No NaN detected → enters else block → calls all_reduce

This creates a deadlock because Rank B is waiting for Rank A to participate in the collective operation, but Rank A never makes the call.

Suggested Solution

You need to ensure all ranks agree on whether to sync before making the collective call:

# First, check if ANY rank has NaN
has_nan = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)

# Create a global NaN indicator across all ranks
has_nan_tensor = torch.tensor(int(has_nan), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX, 
                    group=module.parallel_state.data_parallel_group.group)
if module.parallel_state.context_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.context_parallel_group.group)

if has_nan_tensor.item() > 0:
    # At least one rank has NaN - disable on ALL ranks
    module.awq_lite.is_enabled = False
else:
    # No rank has NaN - safe to sync on ALL ranks
    sync_act_scale_across_dp_cp(
        module,
        module.parallel_state.data_parallel_group,
        module.parallel_state.context_parallel_group,
    )

This ensures all ranks take the same code path and avoid the collective operation mismatch.


AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
9 changes: 8 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
Comment on lines 224 to 233

⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
try:
context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
context_parallel_group = -1
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
context_parallel_group,
)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

super()._setup()

8 changes: 7 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,13 +241,19 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"context_parallel_group: {self.context_parallel_group}"
)


def get_group(ranks: list[int]):
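
Putting the ParallelState extension above together with the Megatron plugin change, a minimal usage sketch: the group lookups and the -1 "unused" sentinel follow the diff and the review suggestion, and the import path for ParallelState is an assumption based on the file being modified.

```python
# Illustrative sketch, not part of the PR: build a ParallelState that carries
# DP (with CP folded in when available), TP, and CP groups.
# Assumes megatron.core model parallelism has already been initialized.
from megatron.core import parallel_state as mcore_parallel

from modelopt.torch.utils.distributed import ParallelState  # assumed import path

try:
    # Prefer the data-parallel group that also includes context parallelism.
    dp_group = mcore_parallel.get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    dp_group = mcore_parallel.get_data_parallel_group()

try:
    cp_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
    cp_group = -1  # sentinel for "not initialized", per the review suggestion

parallel_state = ParallelState(
    dp_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    cp_group,
)
print(parallel_state)  # repr now lists data-, tensor-, and context-parallel groups
```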
14 changes: 11 additions & 3 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -83,9 +83,10 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -383,13 +384,20 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
seed=1234,
context_parallel_size=1,
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)


75 changes: 73 additions & 2 deletions tests/_test_utils/torch_quantization/quantize_common.py
@@ -23,6 +23,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to

@@ -116,8 +117,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
def tensor_parallel_test_helper(model, config, tp_group):
# The input to first layer, the column parallel should be the same across all tp ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

@@ -150,6 +151,76 @@ def forward_loop(model):
dist.destroy_process_group()


def dp_cp_parallel_test_helper(model, config, group):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def reduce_amax(quantizer):
amax = quantizer.amax.clone()
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
assert torch.allclose(amax, quantizer.amax)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
reduce_amax(model.fc1.input_quantizer)
reduce_amax(model.fc2.input_quantizer)

# Weight quantizer amax
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc1.weight_quantizer)
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc2.weight_quantizer)


def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
calib_data = model.get_dummy_input().cuda()
# data should be same across each TP rank
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def reduce_amax(quantizer):
amax = quantizer.amax.clone()
print("amax before reduce", amax)
print("quantizer.amax before reduce", quantizer.amax)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
print("amax after reduce", amax)
print("quantizer.amax after reduce", quantizer.amax)
assert torch.allclose(amax, quantizer.amax)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
reduce_amax(model.fc1.input_quantizer)
reduce_amax(model.fc2.input_quantizer)

if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc1.weight_quantizer)

if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
reduce_amax(quantizer)
else:
reduce_amax(model.fc2.weight_quantizer)


def auto_quantize_helper(model):
model, search_state = mtq.auto_quantize(
model,
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()
94 changes: 90 additions & 4 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -31,6 +31,8 @@
from _test_utils.torch_quantization.quant_utils import get_model_size
from _test_utils.torch_quantization.quantize_common import (
auto_quantize_helper,
data_tensor_context_parallel_test_helper,
dp_cp_parallel_test_helper,
tensor_parallel_test_helper,
)
from packaging.version import Version
@@ -40,6 +42,7 @@
import megatron.core
from megatron.core.parallel_state import (
destroy_model_parallel,
get_context_parallel_group,
get_data_parallel_group,
get_tensor_model_parallel_group,
)
@@ -92,13 +95,12 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
destroy_model_parallel()


# 1. Tensor Parallel Test
def _test_tensor_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
model = MegatronModel(size).cuda()
model = MegatronModel(tp_size=size).cuda()

tensor_parallel_test_helper(
model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
)
tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())


@pytest.mark.parametrize(
@@ -119,6 +121,90 @@ def test_tensor_parallel(need_2_gpus, config):
)


# 2. Data Parallel Test
def _test_data_parallel_helper(config, rank, size):
initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks
model = MegatronModel().cuda()

dp_cp_parallel_test_helper(model, config, get_data_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_parallel(need_2_gpus, config):
spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")


# 3. Context Parallel Test
def _test_context_parallel_helper(config, rank, size):
initialize_for_megatron(
context_parallel_size=size, seed=SEED + rank
) # modify seed so data is different across ranks
model = MegatronModel(cp_size=size).cuda()

dp_cp_parallel_test_helper(model, config, get_context_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_context_parallel(need_2_gpus, config):
spawn_multiprocess_job(
size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
)


# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
def _test_data_tensor_context_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
model = MegatronModel(tp_size=2, cp_size=2).cuda()

data_tensor_context_parallel_test_helper(
model,
config,
get_data_parallel_group(),
get_tensor_model_parallel_group(),
get_context_parallel_group(),
)


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_tensor_context_parallel(need_8_gpus, config):
spawn_multiprocess_job(
size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
)


def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
"""Build the model."""
