
Conversation


@kinjalpatel27 kinjalpatel27 commented Oct 7, 2025

What does this PR do?

New feature
- Added support for quantizing TEGroupedMLP for Megatron-LM
- Added support for synchronizing amax across experts for SequentialMLP
- Added support for synchronizing amax across the expert_model_parallel group

Usage

pytest gpu/torch/quantization/plugins/test_megatron.py -k test_expert_parallel_sync
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_expert_parallel_sync_with_tp
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_te_grouped_vs_sequential_quantize
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_moe_sharded_state_dict

Testing

  • Added tests for amax sync with EP and ETP
  • Added tests to compare outputs between the SequentialMLP and TEGroupedMLP models, before and after quantization
  • Added tests for sharded state dict save and restore

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: -
  • Did you update Changelog?: Not yet

Additional Information

Summary by CodeRabbit

  • New Features

    • MoE quantization with cross-group amax synchronization across data-, tensor-, and expert-parallel groups
    • Transformer Engine (TE) GroupedLinear quantization and TE-aware quantization paths
    • Recognition of alternative weight layouts so more linear layers are detected as quantized
  • Tests

    • MoE-focused quantization tests, amax synchronization checks, utilities for grouped-vs-sequential comparisons, and a new 4-GPU test fixture


copy-pr-bot bot commented Oct 7, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Oct 7, 2025

Walkthrough

The changes add expert-model-parallel (EP) and MoE support to quantization: amax synchronization is extended across DP, EP, and TP groups and across MoE local experts; Transformer Engine GroupedLinear quantization is integrated; Megatron MoE/TE-aware quantizers are added and registered; the distributed ParallelState gains an expert-model-parallel group; and tests and fixtures are expanded for MoE multi-GPU validation.
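
For orientation, the ParallelState change described in the table below might look roughly like the minimal sketch here. The constructor signature and the DistributedProcessGroup stub are assumptions inferred from the summary, not the actual modelopt implementation.

import torch.distributed as dist


class DistributedProcessGroup:
    """Assumed thin wrapper around an optional torch.distributed process group."""

    def __init__(self, group=None):
        self.group = group

    def is_initialized(self) -> bool:
        return dist.is_available() and dist.is_initialized() and self.group is not None


class ParallelState:
    """Sketch: tracks data-, tensor-, and (new in this PR) expert-model-parallel groups."""

    def __init__(
        self,
        data_parallel_group=None,
        tensor_parallel_group=None,
        expert_model_parallel_group=None,
    ):
        self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
        self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
        # New: expert model parallel group, used for MoE amax synchronization.
        self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)

    def __repr__(self) -> str:
        return (
            f"ParallelState(data_parallel_group={self.data_parallel_group.group}, "
            f"tensor_parallel_group={self.tensor_parallel_group.group}, "
            f"expert_model_parallel_group={self.expert_model_parallel_group.group})"
        )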

Changes

Cohort / File(s) Summary
Distributed infrastructure updates
modelopt/torch/utils/distributed.py
Add expert_model_parallel_group parameter to ParallelState, initialize self.expert_model_parallel_group via DistributedProcessGroup, and include it in __repr__.
Quantization calibration & utils
modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/utils.py
Rename and extend the DP sync to sync_quantizer_amax_across_dp_ep so amax is synchronized across both data-parallel and expert-model-parallel groups; add a post-TP pass that calls module.sync_moe_local_experts_amax() when present; update is_quantized_linear() to accept either weight or weight0.
Megatron plugin: MOE and TE integration
modelopt/torch/quantization/plugins/megatron.py
Make _setup lazy for parallel_state; add MOE-aware _MegatronSequentialMLP; add TE-aware quantized classes and registrations (_QuantMegatronTEGroupedLinear, _MegatronTEGroupedColumnParallelLinear, _MegatronTEGroupedRowParallelLinear, _MegatronTEGroupedMLP); propagate parallel_state to local experts and filter TE grouped-linear extra state on load.
Transformer Engine plugin
modelopt/torch/quantization/plugins/transformer_engine.py
Register TE GroupedLinear handler; add _QuantTEGroupedLinear that aliases weight0 to weight during setup/post-restore and provides a TE-aware quantized forward/apply function handling TE argument/layout differences.
Test utilities (Megatron)
tests/_test_utils/torch_dist/plugins/megatron_common.py
Extend helpers with EP/ETP parameters, num_moe_experts, moe_grouped_gemm, and use_te; add copy_weights_from_grouped_to_non_grouped and compare_amax_sync_across_expert_parallel; propagate expert-parallel sizes into model initialization.
Test fixtures
tests/gpu/torch/conftest.py
Add need_4_gpus pytest fixture (skips when CUDA device count < 4); see the sketch after this table.
Megatron quantization tests
tests/gpu/torch/quantization/plugins/test_megatron.py
Expand _gpt_model_provider and _test_sharded_state_dict signatures for MOE/TE; add MOE-focused tests (MOE sharded state-dict, TE-grouped vs sequential quantize, expert-parallel amax sync); import new utilities and MLP types.
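
A minimal sketch of the need_4_gpus fixture noted in the test-fixtures row above; the exact skip message is an assumption:

import pytest
import torch


@pytest.fixture
def need_4_gpus():
    # Skip tests that need a 4-GPU topology when fewer CUDA devices are available.
    if torch.cuda.device_count() < 4:
        pytest.skip("Need at least 4 GPUs to run this test")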

Sequence Diagram(s)

sequenceDiagram
    participant Calib as Calibration Flow
    participant DP as DataParallel
    participant EP as ExpertParallel
    participant TP as TensorParallel
    participant MOE as MOE Local Experts

    Calib->>DP: sync_quantizer_amax_across_dp_ep()
    DP->>DP: collect & reduce amax across DP ranks
    DP->>EP: synchronize amax across expert_model_parallel_group
    Calib->>TP: sync_quantizer_amax_across_tp() (tensor parallel)
    Calib->>MOE: for each module with sync_moe_local_experts_amax -> call it
    MOE->>MOE: sync local-expert amax across EP/ETP ranks
    Note over Calib: Amax synchronized across DP, EP, TP, and MOE local experts
sequenceDiagram
    participant Model as Quantized MOE Model
    participant TE as TransformerEngine GroupedLinear
    participant Seq as Sequential MLP
    participant Qt as Quantizers

    Model->>TE: forward(GroupedLinear)
    TE->>Qt: quantize input
    TE->>Qt: quantize weight (uses `weight0` alias during setup)
    Qt-->>TE: return quant params
    TE->>Model: execute grouped linear op
    Model->>Seq: forward(Sequential MLP)
    Seq->>Qt: input & weight quantize (standard path)
    Seq->>Model: sync amax across expert-parallel groups when calibrating
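
In code form, the calibration flow in the first diagram corresponds roughly to the sketch below. Attribute names such as parallel_state and the quantizer attributes are assumptions used for illustration, and the group wrapper follows the ParallelState sketch earlier in this summary; the actual logic lives in model_calib.py and the Megatron plugin.

import torch.distributed as dist


def sync_amax_after_calibration(model):
    # Sketch of the sync order: reduce amax across DP and EP ranks, then (after the
    # tensor-parallel sync pass, omitted here) sync MoE local experts.
    for module in model.modules():
        parallel_state = getattr(module, "parallel_state", None)
        if parallel_state is None or not dist.is_initialized():
            continue
        for quantizer_name in ("input_quantizer", "weight_quantizer", "output_quantizer"):
            quantizer = getattr(module, quantizer_name, None)
            if quantizer is None or getattr(quantizer, "amax", None) is None:
                continue
            for pg in (parallel_state.data_parallel_group,
                       parallel_state.expert_model_parallel_group):
                if pg is not None and pg.group is not None:
                    # Element-wise max across the group keeps the largest observed amax.
                    dist.all_reduce(quantizer.amax, op=dist.ReduceOp.MAX, group=pg.group)
    # Post-TP pass: modules that own MoE local experts sync amax across those experts.
    for module in model.modules():
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()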

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰
I hopped through syncs of DP and EP,
counted amax under moonlit spree,
grouped weights whispered via TE,
experts aligned in tidy rows,
quant dreams bloom where the rabbit goes. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The PR title "Added support for quantizing TEGroupedMLP for megatron-lm" accurately identifies and focuses on the primary deliverable of this changeset. While the PR introduces several supporting infrastructure changes—including expert parallel group synchronization in distributed.py, amax synchronization logic in model_calib.py, and various MOE-related classes in megatron.py—the core objective stated in the PR summary is to enable quantization of TEGroupedMLP models for Megatron-LM. The title clearly conveys this main outcome and is specific enough that a developer reviewing the commit history would understand the key enablement being added. The supporting changes (expert parallel synchronization, grouped linear support) serve as enabling mechanisms rather than the primary deliverable.



codecov bot commented Oct 7, 2025

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 73.38%. Comparing base (99c44d3) to head (ca55348).
⚠️ Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/model_calib.py 85.71% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #403   +/-   ##
=======================================
  Coverage   73.37%   73.38%           
=======================================
  Files         180      180           
  Lines       17925    17942   +17     
=======================================
+ Hits        13152    13166   +14     
- Misses       4773     4776    +3     

☔ View full report in Codecov by Sentry.

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/grouped_linear branch 2 times, most recently from 22bfe0e to 1c821d8 Compare October 8, 2025 23:54
@realAsma realAsma changed the base branch from main to jennifchen/cp_amax_sync October 9, 2025 15:53
@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/grouped_linear branch from e2858f9 to 4d7dbce Compare October 9, 2025 16:49
)


def _test_expert_model_parallel_amax_sync(
Contributor

nice!

Contributor

We should not need register_custom_post_calibration_plugins. Let's not introduce new infrastructure unnecessarily.

Contributor

I see the point of post_calibration plugins now. Let's keep them as we discussed.

Contributor

this change looks good!

Base automatically changed from jennifchen/cp_amax_sync to main October 10, 2025 23:16
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Kinjal Patel <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

540-564: Cap spawn size to the required topology dimensions.

Spawning torch.cuda.device_count() ranks (line 541) violates the MOE topology constraints when the device count exceeds tp_size * ep_size = 4. For instance, a 6-GPU host will attempt world size 6 with TP=2, yielding DP=3, which fails Megatron's requirement that DP is divisible by EP=2. The need_4_gpus fixture only ensures a minimum, not an exact match.

Apply this diff to restrict the spawn size:

-def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
-    size = torch.cuda.device_count()
+def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm):
+    device_count = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    if device_count < required_size:
+        pytest.skip(f"Requires exactly {required_size} GPUs, found {device_count}")
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP

699-718: Enforce valid world size for MOE topology.

The skip condition size < ep_size * etp_size is insufficient. With TP=2 (hardcoded at line 711), EP=2, and device_count=6, the test spawns 6 ranks, yielding DP=3, which violates Megatron's data_parallel_size % ep_size == 0 constraint. You must ensure the world size is an exact multiple of tp_size * ep_size.

Apply this diff to fix the topology check:

 def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    device_count = torch.cuda.device_count()
+    tp_size = 2  # hardcoded in the partial call below
+    required_size = tp_size * ep_size
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs (TP={tp_size}, EP={ep_size})")
+    # Use the largest valid multiple that doesn't exceed device_count
+    size = (device_count // required_size) * required_size
 
     spawn_multiprocess_job(
         size=size,
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

563-609: Critical: Multi-element amax tensors still break this checker.

Despite the past review comment indicating this was resolved, the code at line 566 still calls module.amax.item(), which raises RuntimeError for per-channel quantizers that produce multi-element tensors. Additionally, lines 607–609 use Python's max() and min() on a list that could contain tensors (when .item() is skipped via the else branch), causing a TypeError.

This breaks synchronization checks for exactly the expert-parallel cases this utility is meant to cover.

Apply this diff to handle multi-element tensors correctly:

             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-                amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()
...
     for quantizer_type, rank_values in expert_quantizers.items():
         if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            max_diff = max(values) - min(values)
-            if max_diff > 1e-6:  # Allow for small floating point differences
+            values = [v.flatten() for v in rank_values.values()]
+            stacked = torch.stack(values)
+            max_diff = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
+            if max_diff > 1e-6:  # Allow for small floating point differences
                 return False, quantizer_type, rank_values
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4919b08 and 23daf38.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
  • size (61-65)
  • world_size (204-206)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • amax (236-241)
  • amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
  • compare_amax_sync_across_expert_parallel (544-611)
  • copy_weights_from_grouped_to_non_grouped (511-541)
  • get_mcore_gpt_model (147-242)
  • initialize_for_megatron (425-444)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-177)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_amax_across_sequential_mlp (51-81)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
  • _ParallelLinear (83-181)
  • _setup (114-122)
  • modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/trace/symbols.py (3)
  • named_modules (444-447)
  • register (289-324)
  • items (434-437)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-257)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

511-542: Weight and bias mapping looks correct.

The template correctly omits the hardcoded .weight suffix, and param_name is properly extracted and appended to the sequential key. This addresses the past review concern about bias parameters mapping incorrectly.

modelopt/torch/quantization/plugins/megatron.py (1)

51-84: LGTM! Amax synchronization correctly uses tensor operations.

The function properly:

  • Uses torch.maximum() for element-wise maximum across experts (lines 72–75)
  • Assigns the synchronized amax back to each expert with proper device placement (lines 78–81)
  • Handles both collection and distribution phases cleanly

This addresses the past review concern about using Python's max() instead of tensor operations.
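
The pattern described above could be sketched as follows; this is an illustrative simplification (the function signature and the iteration over local_experts are assumptions), not the PR's actual implementation.

import torch


def sync_local_experts_amax(local_experts):
    # Collect the element-wise maximum amax for each quantizer name across all local experts.
    merged = {}
    for expert in local_experts:
        for name, submodule in expert.named_modules():
            amax = getattr(submodule, "amax", None)
            if amax is None:
                continue
            if name not in merged:
                merged[name] = amax.detach().clone()
            else:
                merged[name] = torch.maximum(merged[name], amax.to(merged[name].device))
    # Write the synchronized amax back to every expert, preserving each tensor's device.
    for expert in local_experts:
        for name, submodule in expert.named_modules():
            if name in merged and getattr(submodule, "amax", None) is not None:
                submodule.amax = merged[name].to(submodule.amax.device)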

@cjluo-nv cjluo-nv requested a review from sugunav14 October 14, 2025 16:45
__all__ = []


def sync_amax_across_sequential_mlp(model: torch.nn.Module):
Contributor

We should do this only for per-tensor amax

Contributor

Per-channel weight amax: element-wise maximum for RowParallel (fc2; the RowParallel Cout dim is shared across experts).
Per-channel weight amax for ColumnParallel: no-op.

Contributor Author

Will do in the follow-up MR.

Contributor Author

Edit: per-tensor amax also works now. I have modified the test case to correctly check that.

Signed-off-by: Kinjal Patel <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)

642-650: Cap spawn size to 4 ranks for EP×ETP topology.

Line 645 spawns all available GPUs, but the test topology (EP=2, ETP=2) requires exactly 4 ranks. On hosts with more than 4 GPUs, the excess ranks break Megatron's parallelism constraints.

Apply this diff:

 def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     """Test that TEGrouped and sequential MoE models produce similar quantized models."""
     pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # ep_size * etp_size = 2 * 2
+    size = required_size
     spawn_multiprocess_job(

714-737: Cap spawn size to match expert-parallel topology.

Line 719 spawns all available GPUs, but the test parametrization requires at most 4 ranks (max of EP×ETP across test cases is 2×2=4). On hosts with 6+ GPUs, the extra ranks violate Megatron's topology constraints.

Apply this diff:

 def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
+    device_count = torch.cuda.device_count()
+    required_size = ep_size * etp_size
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs for expert model parallel test")
+    size = required_size
 
     if moe_grouped_gemm:

542-569: Cap spawn size to match EP×TP topology.

Line 545 spawns torch.cuda.device_count() ranks, but the test requires exactly 4 ranks (TP=2, EP=2). On a host with 6+ GPUs, the extra ranks violate Megatron's topology constraints (e.g., data-parallel size must be divisible by EP).

Apply this diff:

 def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
     if moe_grouped_gemm:
         pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8bff6b0 and 5481d10.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
  • size (61-65)
  • world_size (204-206)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
  • compare_amax_sync_across_expert_parallel (552-663)
  • copy_weights_from_grouped_to_non_grouped (519-549)
  • get_mcore_gpt_model (149-250)
  • initialize_for_megatron (433-452)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-181)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_moe_local_experts_amax (508-533)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (10)
tests/_test_utils/torch_dist/plugins/megatron_common.py (5)

16-17: LGTM!

The new imports (re and defaultdict) are used in the helper functions added later in the file.


43-46: LGTM!

The expert-parallel imports are correctly added and used in compare_amax_sync_across_expert_parallel.


152-246: LGTM!

The MOE and expert-parallel parameters are correctly propagated through get_mcore_gpt_model and into the model configuration and initialization flow.


438-451: LGTM!

The expert-parallel parameters are correctly added to initialize_for_megatron and propagated to initialize_model_parallel.


519-550: LGTM!

The weight-copying logic correctly handles both weight and bias parameters, and the past review concern about the hardcoded .weight in the template has been properly addressed.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)

24-25: LGTM!

The new helper imports from megatron_common are correctly added and used in the MOE tests below.


47-48: LGTM!

The Megatron MOE imports are correctly added for use in the new MOE-focused tests.


234-282: LGTM!

The MOE-related parameters are correctly added to _gpt_model_provider and properly propagated to the model construction calls.


285-357: LGTM!

The moe_config parameter is correctly integrated into _test_sharded_state_dict, and MOE parameters are properly extracted and propagated to model initialization and provider calls.


572-640: LGTM!

The test helper correctly creates both TEGrouped and Sequential MOE models, copies weights between them, and validates output equivalence before and after quantization.

Signed-off-by: Kinjal Patel <[email protected]>
Signed-off-by: Kinjal Patel <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 91837c3 and ca55348.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • modelopt/torch/quantization/plugins/transformer_engine.py (2 hunks)
  • modelopt/torch/quantization/utils.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.

Applied to files:

  • modelopt/torch/quantization/utils.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_calib.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_moe_local_experts_amax (508-533)
modelopt/torch/quantization/plugins/transformer_engine.py (1)
modelopt/torch/quantization/plugins/custom.py (1)
  • _ParallelLinear (76-174)

Comment on lines +180 to +182
if hasattr(module, "sync_moe_local_experts_amax"):
    module.sync_moe_local_experts_amax()


⚠️ Potential issue | 🔴 Critical

Guard MOE expert sync behind an initialized process group

max_calibrate is invoked in single-process flows. The new call into module.sync_moe_local_experts_amax() executes torch.distributed.barrier() unconditionally, so on a non-initialized default group this now throws RuntimeError: Default process group has not been initialized. Please gate this loop on dist.is_available() / dist.is_initialized() (or make the callee accept a group handle) so single-process calibration keeps working.

-    for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+    if dist.is_available() and dist.is_initialized():
+        for name, module in model.named_modules():
+            if hasattr(module, "sync_moe_local_experts_amax"):
+                module.sync_moe_local_experts_amax()

Comment on lines +72 to +115
    def _setup(self):
        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
        # self.weight0 to self.weight to run the quantizer states initialization.
        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
        self.weight = self.weight0
        # Memorize the original weight.dtype for modelopt_post_restore given that
        # the dtype can change later.
        super()._setup()
        # Remove self.weight after setup.
        delattr(self, "weight")

    def modelopt_post_restore(self, prefix: str = ""):
        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
        # self.weight0 to self.weight to run the quantizer states initialization.
        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
        self.weight = self.weight0
        super().modelopt_post_restore(prefix=prefix)
        # Remove self.weight after post_restore.
        delattr(self, "weight")

    @staticmethod
    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
        idx = 1 if func_name == "_forward" else 0
        inp = args[idx]
        num_gemms = len(args[idx + 1])
        weights_and_biases = args[-2 * num_gemms :]
        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
        quantized_inputs = self.input_quantizer(inp)
        quantized_weights = [self.weight_quantizer(weight) for weight in weights]

        output = getattr(package, func_name)(
            *(
                args[0],
                quantized_inputs,
            )
            if func_name == "_forward"
            else (quantized_inputs,),
            *args[idx + 1 : -2 * num_gemms],
            *quantized_weights,
            *biases,
        )
        return self.output_quantizer(output)

⚠️ Potential issue | 🔴 Critical

Expose a stable .weight view for grouped TE layers

With is_quantized_linear() now recognizing modules that only provide weight0, helpers such as smoothquant, disable_pre_quant_scale_and_resmooth, etc. immediately access module.weight. Because _QuantTEGroupedLinear deletes that alias after setup, those helpers now hit AttributeError and break the quantization flows for grouped TE models. Please keep a .weight view backed by weight0 (without registering a duplicate parameter) so the existing utilities continue to function.

     def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
+        # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
     def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
+        # Setter below keeps weight0 as the canonical tensor.
+
+    @property
+    def weight(self):
+        return self.weight0
+
+    @weight.setter
+    def weight(self, value):
+        if value is not self.weight0:
+            raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")

@kinjalpatel27 kinjalpatel27 merged commit 6ef9954 into main Oct 17, 2025
27 checks passed
@kinjalpatel27 kinjalpatel27 deleted the kinjal/grouped_linear branch October 17, 2025 21:22
yeyu-nvidia pushed a commit that referenced this pull request Dec 8, 2025
soodoshll pushed a commit to soodoshll/TensorRT-Model-Optimizer that referenced this pull request Dec 8, 2025
