9 changes: 5 additions & 4 deletions examples/nemo_run/qat/README.md
@@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar

You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer`, `NeMo`, and `Megatron-LM` repositories, then mount them into your Docker container.

- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git`
- `git clone https://github.com/NVIDIA/Megatron-LM.git`

Example docker command:

```bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```

You will also need to set your Hugging Face token with `export HF_TOKEN=<your-token>`. You may also need to give the Docker container write access to the `examples/nemo_run` folder by running `chmod 777 nemo_run` so that logs can be written.
@@ -92,7 +93,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
32 changes: 26 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and context parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_cp(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -598,19 +599,38 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
# Sync across Data Parallel (DP)
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)
# Sync across Context Parallel (CP)
if context_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp_cp(
module,
module.parallel_state.data_parallel_group,
module.parallel_state.context_parallel_group,
)
Comment on lines 615 to 627
@coderabbitai bot, Sep 29, 2025:

⚠️ Potential issue | 🟠 Major

Past review concern about NaN propagation remains unresolved.

The previous review identified that NaN detection is performed locally (lines 624-627), but the synchronization call (lines 629-633) will still propagate NaNs from any rank to all ranks in the DP/CP group through the AVG operation. This creates an inconsistent state where:

  1. Rank A detects NaN → sets is_enabled = False locally
  2. Rank B has valid data → enters else block and calls sync
  3. During sync, Rank A's NaN propagates to Rank B via AVG operation
  4. Rank B now has NaN-corrupted act_scale but is_enabled = True

Consider one of these solutions:

  • Add a collective NaN check before sync: use all_reduce to check if ANY rank has NaN, then disable all ranks if true
  • Move the sync before the NaN check and perform NaN detection after synchronization
  • Guard the sync call: only call sync_act_scale_across_dp_cp when module.awq_lite.is_enabled is true (but this requires collective agreement on enabled state)

Based on past review comments.

Contributor Author:

I now only sync the act_scale if no NaNs have been detected (it is after an else statement).

@jenchen13 Thank you for the update. While moving the sync to the else block prevents syncing on ranks that detect NaN locally, there's a critical distributed synchronization issue that remains:

The Problem: Collective Operation Mismatch

The dist.all_reduce() call inside sync_act_scale_across_dp_cp is a collective operation that requires all ranks in the DP/CP group to participate. With the current implementation:

  • Rank A (has NaN): Detects NaN → sets is_enabled = Falseskips the else block → does NOT call all_reduce
  • Rank B (no NaN): No NaN detected → enters else block → calls all_reduce

This creates a deadlock because Rank B is waiting for Rank A to participate in the collective operation, but Rank A never makes the call.

Suggested Solution

You need to ensure all ranks agree on whether to sync before making the collective call:

```python
# First, check if ANY rank has NaN
has_nan = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)

# Create a global NaN indicator across all ranks
has_nan_tensor = torch.tensor(int(has_nan), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group)
if module.parallel_state.context_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.context_parallel_group.group)

if has_nan_tensor.item() > 0:
    # At least one rank has NaN - disable on ALL ranks
    module.awq_lite.is_enabled = False
else:
    # No rank has NaN - safe to sync on ALL ranks
    sync_act_scale_across_dp_cp(
        module,
        module.parallel_state.data_parallel_group,
        module.parallel_state.context_parallel_group,
    )
```

This ensures all ranks take the same code path and avoid the collective operation mismatch.


AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
9 changes: 8 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
Comment on lines 224 to 233

⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

```diff
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```diff
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            data_parallel_group = get_data_parallel_group()
-        self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

super()._setup()

8 changes: 7 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,13 +241,19 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"context_parallel_group: {self.context_parallel_group}"
)


def get_group(ranks: list[int]):
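For context, here is a minimal sketch (not part of the diff) of how the extended `ParallelState` constructor might be wired up, mirroring the guarded group lookup from `plugins/megatron.py` above. It assumes `torch.distributed` and Megatron-Core model parallelism are already initialized and that context parallelism is enabled.

```python
# Sketch only: assumes Megatron-Core parallel state is initialized with CP enabled.
from megatron.core import parallel_state as mcore_parallel

from modelopt.torch.utils.distributed import ParallelState

try:
    # Prefer the combined DP+CP group when context parallelism is available.
    dp_group = mcore_parallel.get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    dp_group = mcore_parallel.get_data_parallel_group()

parallel_state = ParallelState(
    dp_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
)
print(parallel_state)  # __repr__ now also reports the context_parallel_group
```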
14 changes: 11 additions & 3 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -83,9 +83,10 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -383,13 +384,20 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
seed=1234,
context_parallel_size=1,
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)


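As a usage sketch (not part of the diff), a spawned test worker might drive the extended helpers roughly as follows; the 4-GPU TP=2 × CP=2 layout and the import path are assumptions based on the test layout above.

```python
# Sketch only: assumes 4 GPUs, torch.distributed already initialized in this worker,
# and that the helpers are importable from the tests' _test_utils package.
from megatron.core.parallel_state import destroy_model_parallel

from _test_utils.torch_dist.plugins.megatron_common import (
    MegatronModel,
    initialize_for_megatron,
)

# TP=2 and CP=2, so tensor * context parallel sizes match a 4-GPU world.
initialize_for_megatron(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    context_parallel_size=2,
)
model = MegatronModel(tp_size=2, cp_size=2).cuda()
# ... run quantization / calibration checks here ...
destroy_model_parallel()  # clean up parallel state, per the helper's docstring
```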
155 changes: 142 additions & 13 deletions tests/_test_utils/torch_quantization/quantize_common.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from unittest.mock import patch

import pytest
import torch
@@ -22,7 +23,9 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to

@@ -116,40 +119,166 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
quantizer_attr = getattr(quantizer, attr).clone()
print("quantizer.attr before reduce", getattr(quantizer, attr))
dist.all_reduce(quantizer_attr, op=op, group=group)
print("quantizer.attr after reduce", getattr(quantizer, attr))
print("quantizer_attr after reduce", quantizer_attr)
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


⚠️ Potential issue | 🟡 Minor

Fix default arg type and reduce noisy prints.

  • attr default should be a string annotation, not the str type.
  • Unconditional prints will spam across ranks. Gate on rank 0 or remove.
```diff
-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```diff
-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
-    quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
-    dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
-    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
+    dist.all_reduce(quantizer_attr, op=op, group=group)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 122-129,
the function signature incorrectly uses attr=str (making the default the str
type) and unconditionally prints from every rank; change the signature to
annotate attr as a string (e.g. def _reduce_quantizer_attr(quantizer, attr: str,
op=dist.ReduceOp.MAX, group=None):) so attr is typed properly (no stray
default), and remove or gate the print statements behind a single rank (e.g.
only print when dist.is_initialized() and dist.get_rank() == 0) to avoid
spamming across ranks; keep the clone, all_reduce, and assertion logic
unchanged.


original_awq_lite = model_calib_module.awq_lite


def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
"""Function to mock awq_lite function to always use debug=True for testing"""
return original_awq_lite(model, forward_loop, alpha_step, debug=True)

Comment on lines +134 to +137

⚠️ Potential issue | 🔴 Critical

Forward the AWQ-Lite kwargs in the patch

The _debug_awq_lite wrapper drops every extra keyword argument that callers pass to awq_lite (e.g., tensor_parallel_group, data_parallel_group, max_calib_steps). The upstream API explicitly accepts **kwargs, so the first call that includes one of those options will now raise a TypeError, breaking AWQ-Lite calibration in the very paths this PR is exercising. Please mirror the original signature and forward **kwargs to original_awq_lite while forcing debug=True.

```diff
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
-    """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Force awq_lite debug mode during tests without dropping optional args."""
+    return original_awq_lite(
+        model,
+        forward_loop,
+        alpha_step=alpha_step,
+        debug=True,
+        **kwargs,
+    )
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```diff
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
-    """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Force awq_lite debug mode during tests without dropping optional args."""
+    return original_awq_lite(
+        model,
+        forward_loop,
+        alpha_step=alpha_step,
+        debug=True,
+        **kwargs,
+    )
```
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 134 to
137, the _debug_awq_lite wrapper drops any extra keyword arguments callers pass
to awq_lite which causes TypeError when upstream calls include options like
tensor_parallel_group or max_calib_steps; update the wrapper to mirror the
original awq_lite signature by accepting *args and **kwargs (or the same
explicit params plus **kwargs) and forward them to original_awq_lite while
forcing debug=True (i.e., call original_awq_lite(..., debug=True, **kwargs) so
all upstream options are preserved).


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite):
# The input to first layer, the column parallel should be the same across all tp ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Sanity check
forward_loop(model)

if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
# Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
activation_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)

_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)
# Lets check the row parallel weight amax; it should be the same across all tp ranks
weight_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
_reduce_quantizer_attr(
model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
)

if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
input_quantizer = model.fc1.input_quantizer
pre_quant_scale = input_quantizer.pre_quant_scale.clone()
dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
_reduce_quantizer_attr(
input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
)

if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Check activation scale for AWQ lite
_reduce_quantizer_attr(
Contributor Author @jenchen13, Oct 2, 2025:
@realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the c_in dimension in activation ... is this right?

model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=tp_group,
)
# TODO fc2 assert is failing
"""
_reduce_quantizer_attr(
model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group,
)
"""

dist.destroy_process_group()


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

# Weight quantizer amax
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
else:
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
else:
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)

if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Check act scale
_reduce_quantizer_attr(
model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=group,
)
_reduce_quantizer_attr(
model.fc2.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=group,
)


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def data_tensor_context_parallel_test_helper(
model, config, dp_group, tp_group, cp_group, mock_awq_lite
):
calib_data = model.get_dummy_input().cuda()
# data should be same across each TP rank
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
quantizer_attr = getattr(quantizer, attr).clone()
print("quantizer_attr before reduce", quantizer_attr)
print("quantizer.attr before reduce", getattr(quantizer, attr))
dist.all_reduce(quantizer_attr, op=op, group=dp_group)
dist.all_reduce(quantizer_attr, op=op, group=cp_group)
dist.all_reduce(quantizer_attr, op=op, group=tp_group)
print("quantizer_attr after reduce", quantizer_attr)
print("quantizer.attr after reduce", getattr(quantizer, attr))
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
else:
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)

if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
else:
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)

# Check act scale
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
_reduce_quantizer_attr(
model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
)
_reduce_quantizer_attr(
model.fc2.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
)


def auto_quantize_helper(model):
model, search_state = mtq.auto_quantize(
model,
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()
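Finally, a hypothetical sketch of how a test might opt into the new fixture; the test name and body are illustrative only, not from this PR.

```python
# Hypothetical usage: the fixture skips the test when fewer than 8 GPUs are visible.
import torch


def test_data_tensor_context_parallel_quantize(need_8_gpus):
    # Reaching this point implies torch.cuda.device_count() >= 8; otherwise pytest skipped.
    assert torch.cuda.device_count() >= 8
    # ... spawn 8 ranks and call the DP/TP/CP helpers from quantize_common.py ...
```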