
Conversation

@jenchen13 (Contributor) commented Sep 24, 2025

What does this PR do?

Type of change: New Feature

Overview: Sync quantizer amax across context-parallel (CP) groups, and AWQ-Lite act_scale across CP/DP groups.

Usage

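A minimal sketch of how the feature is exercised (hypothetical usage; `model` and `calib_dataloader` are placeholders, and the amax/act_scale sync itself happens automatically inside calibration):

import modelopt.torch.quantization as mtq
from megatron.core import parallel_state

# Assumption: torch.distributed is already initialized and `model` is a Megatron-Core module.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,
    context_parallel_size=2,  # amax / act_scale are now also synced across this group
)

def forward_loop(model):
    # `calib_dataloader` stands in for whatever calibration data you use.
    for batch in calib_dataloader:
        model(batch)

# Calibration all-reduces quantizer amax (MAX) and AWQ-Lite act_scale (AVG)
# across the data- and context-parallel groups.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)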

Testing

  • tests for DP, CP, and DP/TP/CP combined
  • tests for AWQ lite act_scale

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Parallel-aware synchronization for quantization calibration, including activation-scale harmonization across data, tensor, and context parallel groups; safer Megatron parallel initialization with a logged fallback.
  • Documentation

    • Example README/docker updated: image tag bumped to 25.09+, Megatron‑LM repo mount added, and tensor_parallelism=4 flag documented for QAD runs.
  • Style

    • Improved parallel-state string formatting for clearer multi-line output.
  • Tests

    • Expanded DP/TP/CP test coverage with new helpers, seeded inputs, an 8-GPU fixture, and additional parallel test suites.


copy-pr-bot bot commented Sep 24, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 24, 2025

Walkthrough

Adds context-parallel (CP) synchronization for quantization calibration (amax and act_scale), updates Megatron plugin to explicitly retrieve a data-parallel group with a CP-aware getter (with fallback), changes ParallelState string representation, expands tests and fixtures to cover DP/TP/CP scenarios (including an 8‑GPU fixture), and updates an example README and docker invocation (image tag, Megatron-LM mount, tensor_parallelism flag).

Changes

Cohort / File(s) / Summary

  • Quantization calibration CP sync (modelopt/torch/quantization/model_calib.py): Adds DP synchronization and new CP-aware sync helpers for quantizer amax and AWQ-Lite act_scale; integrates synchronization into the calibration and awq_lite paths; adds a conditional bypass (MoE/hack) and inline docs.
  • Megatron plugin DP retrieval (modelopt/torch/quantization/plugins/megatron.py): Retrieves data_parallel_group via get_data_parallel_group(with_context_parallel=True) with AssertionError handling and a fallback to get_data_parallel_group(); logs a warning on fallback and passes the explicit group into the ParallelState setup.
  • ParallelState repr only (modelopt/torch/utils/distributed.py): Changes ParallelState.__repr__ formatting to a multi-line string that includes data_parallel_group and tensor_parallel_group; no structural API changes.
  • Quantization test helpers & patches (tests/_test_utils/torch_quantization/quantize_common.py): Adds _reduce_quantizer_attr, an AWQ-Lite debug wrapper and patching, updates the tensor_parallel_test_helper signature to accept mock_awq_lite, and adds dp_cp_parallel_test_helper and data_tensor_context_parallel_test_helper for DP/TP/CP verification.
  • Megatron test common: CP wiring (tests/_test_utils/torch_dist/plugins/megatron_common.py): Adds a cp_size parameter to MegatronModel.__init__ and forwards context_parallel_size through initialize_for_megatron → initialize_model_parallel; get_dummy_input accepts an optional seed.
  • Megatron quantization tests (tests/gpu/torch/quantization/plugins/test_megatron.py): Wires new DP/CP/DP+TP+CP test suites and helpers; adapts model instantiation to tp_size/cp_size; increases num_attention_heads in test models; exposes/uses the new test helpers.
  • GPU test fixtures (tests/gpu/torch/conftest.py): Adds a need_8_gpus pytest fixture that skips tests when fewer than 8 GPUs are available.
  • Docs & example runbook (examples/nemo_run/qat/README.md): Updates the NeMo container image tag 25.07 → 25.09, adds a Megatron-LM repo clone and mount path in the docker command, and adds the --tensor_parallelism 4 flag to the QAD training invocation.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Calib as Calibration Loop
  participant Module as Quantized Module
  participant Q as Tensor/SequentialQuantizer
  participant DP as Data-Parallel Group
  participant CP as Context-Parallel Group

  Calib->>Module: collect stats (forward)
  Module->>Q: update amax / act_scale
  Note over Module,Q: after per-batch accumulation
  alt Data-parallel sync
    Module->>DP: all-reduce amax / act_scale
  end
  alt Context-parallel sync
    Module->>CP: all-reduce amax / act_scale
  end
  Q-->>Module: synchronized scales
  Module-->>Calib: continue/finalize calibration
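In code, the sync step in the diagram amounts to something like the following (illustrative sketch only, not the actual model_calib.py helper; the function name and group handling are assumptions):

import torch.distributed as dist

def sync_quantizer_amax_across_dp_cp(quantizer, data_parallel_group, context_parallel_group):
    # Illustrative: take the max amax seen by any DP/CP replica so all ranks agree.
    if getattr(quantizer, "amax", None) is None:
        return
    for group in (data_parallel_group, context_parallel_group):
        if group is not None:
            dist.all_reduce(quantizer.amax, op=dist.ReduceOp.MAX, group=group)
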
sequenceDiagram
  autonumber
  participant Init as Megatron Plugin Init
  participant MPG as megatron.core.parallel_state
  participant PS as ParallelState

  Init->>MPG: get_data_parallel_group(with_context_parallel=true)
  alt context-parallel available
    MPG-->>Init: dp_group (CP-aware)
  else fallback
    Init->>MPG: get_data_parallel_group()
    MPG-->>Init: dp_group (fallback)
  end
  Init->>MPG: get_tensor_model_parallel_group()
  Init->>MPG: get_context_parallel_group()
  Init->>PS: __init__(data_parallel_group=dp_group, tensor_parallel_group=tp_group, ...)
  PS-->>Init: ParallelState ready

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A rabbit nudges bytes with nimble paws,
CP, DP, and TP bounce through the laws.
Scales sync up, quantizers sing,
Tests hop onward—CI takes wing.
Docker spins, Megatron hums—hooray, applause! 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name / Status / Explanation
  • Title Check ✅ Passed: The title concisely and accurately describes the core feature added (synchronizing quantizer amax values and AWQ-Lite activation scales across context- and data-parallel groups), so it clearly summarizes the main change.
  • Docstring Coverage ✅ Passed: Docstring coverage is 10.00%, which is insufficient. The required threshold is 80.00%.
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.



codecov bot commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 66.66667% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.78%. Comparing base (cb44c55) to head (3f857a3).
⚠️ Report is 2 commits behind head on main.

Files with missing lines: modelopt/torch/quantization/model_calib.py (patch coverage 66.66%, 2 lines missing ⚠️)
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #359      +/-   ##
==========================================
- Coverage   73.79%   73.78%   -0.01%     
==========================================
  Files         171      171              
  Lines       17591    17596       +5     
==========================================
+ Hits        12982    12984       +2     
- Misses       4609     4612       +3     

@realAsma (Contributor) commented:

@jenchen13 could you please add unit tests for context parallel quantization (similar to tensor parallel) to here -

def test_tensor_parallel(need_2_gpus, config):

basically the TP test checks whether amax is similar across the TP group. see

def tensor_parallel_test_helper(model, config, tp_group, dp_group):
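
For illustration, a CP variant could mirror that helper: calibrate with per-rank data, then assert amax agrees across the CP group (sketch only; the group argument and field names are assumptions):

def context_parallel_test_helper(model, config, cp_group):
    # calibrate as in the TP helper, then check amax is identical across CP ranks
    amax = model.fc1.input_quantizer.amax.clone()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
    assert torch.allclose(amax, model.fc1.input_quantizer.amax)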


copy-pr-bot bot commented Sep 24, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Jennifer Chen <[email protected]>
@jenchen13 force-pushed the jennifchen/cp_amax_sync branch from e764e79 to 42519cc on September 25, 2025 18:57
Signed-off-by: Jennifer Chen <[email protected]>
@jenchen13 force-pushed the jennifchen/cp_amax_sync branch from aa5b8fd to 264adbb on September 25, 2025 21:46
@jenchen13 changed the title from "sync amax in context parallel" to "Sync amax & AWQ-Lite in context parallel/data parallel" on Sep 25, 2025
@jenchen13 changed the title from "Sync amax & AWQ-Lite in context parallel/data parallel" to "Sync amax & AWQ-Lite act_scale in context parallel/data parallel" on Sep 25, 2025
@jenchen13 marked this pull request as ready for review on September 25, 2025 23:23
@jenchen13 requested review from a team as code owners on September 25, 2025 23:23
@coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

132-149: Replace config-based guards with attribute-presence checks and sync AWQ pre_quant_scale
• For amax:

if model.fc2.input_quantizer.amax is not None:
    activation_amax = model.fc2.input_quantizer.amax.clone()
    dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)

• For scales (SmoothQuant, AWQ/AWQ-Lite):

# input scale
if (scale := model.fc1.input_quantizer.pre_quant_scale) is not None:
    scale_clone = scale.clone()
    dist.all_reduce(scale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(scale_clone, scale)
# weight scale (AWQ-Lite)
if (wscale := model.fc1.weight_quantizer.pre_quant_scale) is not None:
    wscale_clone = wscale.clone()
    dist.all_reduce(wscale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(wscale_clone, wscale)

Drop all if config in […] checks.

🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)

628-651: Weight DP/CP averages by token counts

Right now we average act_scale equally across ranks. In mixed-workload runs (e.g., MoE routing) we can see uneven num_tokens, so the lighter ranks end up pulling the mean down. Since we already track num_tokens, we can switch to a weighted reduction (sum of scale * tokens and sum of tokens) before normalizing.

A sketch:

-    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-    sync_act_scale_across_dp_cp(...)
+    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+    token_count = torch.tensor(
+        [module.awq_lite.num_tokens],
+        device=module.awq_lite.act_scale.device,
+        dtype=module.awq_lite.act_scale.dtype,
+    )
+    scale_sum = module.awq_lite.act_scale * token_count.item()
+    sync_reduction_across_dp_cp(scale_sum, token_count, module.parallel_state)
+    module.awq_lite.act_scale = scale_sum / token_count

(sync_reduction_across_dp_cp would all-reduce both tensors across DP/CP groups.)

tests/_test_utils/torch_quantization/quantize_common.py (1)

215-237: 3D helper: sync input pre_quant_scale across TP/CP/DP

  • Replace the config-based presence check with an attribute check for input amax (e.g. if getattr(model.fc1.input_quantizer, "amax", None) is not None) in tests/_test_utils/torch_quantization/quantize_common.py → data_tensor_context_parallel_test_helper.
  • If input_quantizer.pre_quant_scale is present, clone it, all_reduce (MAX) across tp_group, cp_group, dp_group and assert torch.allclose(reduced, input_quantizer.pre_quant_scale).
  • AWQ-Lite’s activation scale is stored on module.awq_lite.act_scale and model_calib syncs it using AVG across DP/CP — if you want to validate AWQ-Lite end-to-end, also check module.awq_lite.act_scale is synchronized (use the same group ops as model_calib).
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b4d6ced and 7cbe5b9.

📒 Files selected for processing (8)
  • examples/nemo_run/qat/README.md (1 hunks)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
  • modelopt/torch/utils/distributed.py (1 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (2 hunks)
  • tests/gpu/torch/conftest.py (1 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
  • parallel_state (876-878)
  • parallel_state (881-886)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-256)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (4)
  • context_parallel_test_helper (179-202)
  • data_parallel_test_helper (153-176)
  • data_tensor_context_parallel_test_helper (205-237)
  • tensor_parallel_test_helper (119-150)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-398)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-131)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • amax (231-236)
  • amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (7)
modelopt/torch/quantization/model_calib.py (1)

82-98: Nice DP/CP amax sync integration

The recursive helper cleanly reuses the existing SequentialQuantizer handling while adding the CP hop, so both DP and CP replicas end up aligned. 👍

tests/gpu/torch/conftest.py (1)

37-40: need_8_gpus fixture looks good

The skip guard mirrors need_2_gpus, so multi-rank tests will short-circuit cleanly when the hardware isn’t there.

tests/gpu/torch/quantization/plugins/test_megatron.py (1)

176-205: Great coverage for the 2×2×2 scenario

Spinning up the combined DP/TP/CP path ensures the new sync logic is exercised end-to-end; thanks for wiring the groups explicitly.

tests/_test_utils/torch_quantization/quantize_common.py (2)

129-131: Good addition: post-quantization sanity forward.

Running a forward pass after quantize helps catch latent issues. Looks good.


205-212: LGTM on the sequential TP→CP→DP all-reduce pattern.

This correctly propagates maxima across orthogonal groups in 3D parallelism.

Please confirm group construction matches orthogonal decomposition (i.e., each rank belongs to exactly one group per dimension). If not, propagation may be incomplete.

examples/nemo_run/qat/README.md (1)

95-95: Doc command update looks good.

Nice to see the example reflecting the new tensor-parallel setup.

modelopt/torch/utils/distributed.py (1)

244-256: Context group wiring looks solid.

ParallelState now mirrors DP/TP/CP consistently, so downstream logging/debugging will show the full layout.

Comment on lines 221 to 230
        data_parallel_group = None
        try:
            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
        except AssertionError:
            data_parallel_group = get_data_parallel_group()
        self.parallel_state = ParallelState(
            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
            data_parallel_group,
            mcore_parallel.get_tensor_model_parallel_group(),
            mcore_parallel.get_context_parallel_group(),
        )

⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            data_parallel_group = get_data_parallel_group()
-        self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

Signed-off-by: Jennifer Chen <[email protected]>
@coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cbe5b9 and 1f7d17e.

📒 Files selected for processing (2)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_quantization/quantize_common.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1146-1254)
  • amax (231-236)
  • amax (239-250)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (185-217)
  • dp_cp_parallel_test_helper (154-182)
  • tensor_parallel_test_helper (120-151)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-398)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)

151-151: Remove process group teardown from helper

We’re still tearing down the global process group inside the helper, which the test harness owns. Subsequent jobs in the same worker will hit a destroyed NCCL group. Please drop this call.

-    dist.destroy_process_group()

162-182: Guard quantizer sync for AWQ-Lite and validate act_scale

AWQ-Lite quantizers set is_mx_format, so TensorQuantizer.amax returns None (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py lines 230-235). The current helper unconditionally clones .amax, which will raise AttributeError as soon as we run the AWQ-Lite DP/CP test. Even when it doesn’t explode, we never exercise the newly synchronized act_scale, so regressions slide through. Please guard for missing amax and also check pre_quant_scale/act_scale when present.

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        validated = False
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is None:
+                continue
+            synced = tensor.detach().clone()
+            dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+            assert torch.allclose(synced, tensor), f"{attr} mismatch on {quantizer}"
+            validated = True
+        if not validated:
+            pytest.fail(f"{quantizer} exposed no distributed attribute to validate")
@@
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        reduce_quant_attrs(model.fc1.input_quantizer)
+        reduce_quant_attrs(model.fc2.input_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

@coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

602-612: Consider adding a docstring for consistency.

The act_scale synchronization logic is correct (using AVG for statistics is appropriate), but the function lacks a docstring.

Apply this diff to add a docstring:

 def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+    """Synchronize the act_scale across all ranks in the data parallel and context parallel groups."""
     # Sync across Data Parallel (DP)
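
For reference, the full helper presumably looks roughly like this (sketch reconstructed from the review context; the is_initialized()/group attributes follow the ParallelState wrappers used elsewhere in this thread and are assumptions, not the verified implementation):

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
    """Synchronize the act_scale across all ranks in the data parallel and context parallel groups."""
    # AVG (not MAX) because act_scale is a running statistic rather than a bound.
    for pg in (data_parallel_group, context_parallel_group):
        if pg.is_initialized():
            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=pg.group)
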
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f7d17e and 71a9f7a.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
modelopt/torch/utils/distributed.py (2)
  • is_initialized (49-51)
  • is_initialized (196-198)
modelopt/torch/quantization/utils.py (1)
  • is_quantized_linear (246-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/model_calib.py (2)

83-91: LGTM! CP support properly added to amax synchronization.

The extension from DP-only to DP+CP synchronization is implemented correctly. The function handles both SequentialQuantizer recursion and direct TensorQuantizer sync, and the docstring accurately reflects the new behavior.


98-98: Correct update to use CP-aware synchronization.

The call site properly updated to invoke the new CP-aware sync function.

Signed-off-by: Jennifer Chen <[email protected]>
@coderabbitai bot left a comment

Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71a9f7a and d02365c.

📒 Files selected for processing (5)
  • modelopt/torch/quantization/model_calib.py (2 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
  • tests/gpu/torch/quantization/test_model_calib.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py
🧰 Additional context used
🧬 Code graph analysis (4)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (185-221)
  • dp_cp_parallel_test_helper (154-182)
  • tensor_parallel_test_helper (120-151)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-401)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/gpu/torch/quantization/test_model_calib.py (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • MegatronModel (85-131)
  • initialize_for_megatron (386-401)
  • get_dummy_input (130-131)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-677)
tests/_test_utils/torch_quantization/quantize_common.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1114-1222)
  • amax (236-241)
  • amax (244-255)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-131)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
  • QuantModule (37-96)
modelopt/torch/utils/distributed.py (2)
  • is_initialized (49-51)
  • is_initialized (196-198)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (13)
modelopt/torch/quantization/model_calib.py (2)

83-98: LGTM! Well-structured CP-aware amax synchronization.

The implementation correctly extends amax synchronization to context parallel groups while maintaining backward compatibility with data parallel. The recursive handling of SequentialQuantizer is appropriate, and using ReduceOp.MAX for both DP and CP groups ensures correct semantics (maximum amax across all ranks).


602-613: LGTM! Correct use of averaging for act_scale synchronization.

The function properly synchronizes act_scale using ReduceOp.AVG across both DP and CP groups, which is the correct reduction operation for averaging activation scales. The guard checks ensure synchronization only occurs when the groups are initialized.

tests/gpu/torch/quantization/test_model_calib.py (1)

32-33: LGTM! Test setup is correct.

The test properly uses spawn_multiprocess_job with 2 GPUs and NCCL backend for distributed testing.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)

34-35: LGTM!

The new imports are necessary for the DP/CP test helpers and context-parallel group retrieval used in the tests below.

Also applies to: 45-45


101-103: LGTM!

The explicit tp_size keyword argument improves clarity, and removing dp_group from the tensor_parallel_test_helper call aligns with the updated signature in quantize_common.py.


124-130: Per-rank seed is overridden; test won't catch broken DP sync.

Passing SEED + rank to initialize_for_megatron is overridden by the internal call to model_parallel_cuda_manual_seed(seed) (see tests/_test_utils/torch_dist/plugins/megatron_common.py, lines 385-400), so all ranks still produce identical get_dummy_input() activations. The test will pass even if DP synchronization is broken. Introduce a rank-dependent perturbation after initialization—e.g., reseed or add a small offset before calling dp_cp_parallel_test_helper.

Based on past review comments.


148-156: Per-rank seed is overridden; test won't catch broken CP sync.

The same issue from the DP test applies here: initialize_for_megatron internally calls model_parallel_cuda_manual_seed(seed) with the provided seed, overriding the per-rank divergence you intended with SEED + rank. All CP ranks will produce identical calibration data, so the test won't fail if CP synchronization regresses. Add a rank-dependent perturbation after initialization.

Based on past review comments.


176-187: Fixed seed produces identical calibration data; test won't catch broken DP/CP sync.

Line 178 uses SEED without rank-dependent divergence. Since initialize_for_megatron calls model_parallel_cuda_manual_seed(SEED) uniformly across all 8 ranks, every rank will produce identical get_dummy_input() activations, so the assertions in data_tensor_context_parallel_test_helper will pass even if DP or CP synchronization is broken. Introduce rank-dependent perturbation (e.g., SEED + rank + 1) or add a small offset after initialization to ensure different calibration data per DP/CP rank.

Based on past review comments.
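
A small sketch of the per-rank divergence being asked for (hypothetical; helper and fixture names follow the test utilities referenced above):

    initialize_for_megatron(
        tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED
    )
    # Feed each rank different calibration data so a missing DP/CP sync actually fails.
    rank = torch.distributed.get_rank()
    calib_data = model.get_dummy_input(seed=SEED + rank + 1).cuda()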

tests/_test_utils/torch_quantization/quantize_common.py (5)

26-26: LGTM!

The SequentialQuantizer import is necessary for the new helpers to handle multi-format weight quantization correctly.


120-120: LGTM!

Removing the unused dp_group parameter simplifies the signature; the function only validates tensor-parallel synchronization.


151-151: Remove global process group destruction from helper.

This call unconditionally tears down the global process group, breaking any subsequent DP/CP/TP tests that run in the same process. The test harness owns the process group lifecycle. Remove this line.

Based on past review comments.

Apply this diff:

-    dist.destroy_process_group()

154-183: Guard amax access and add AWQ-Lite scale validation.

Line 163 unconditionally accesses quantizer.amax, which returns None for MX formats (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py, line 237) and will crash on .clone(). The config equality check at line 168 is brittle and misses future configs. Additionally, the PR objective includes synchronizing AWQ-Lite act_scale, but this helper doesn't validate it.

Based on past review comments.

Replace config checks with attribute guards and add scale validation:

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"
 
     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)
 
     # Weight quantizer amax
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

185-222: Guard amax access, add AWQ-Lite scale validation, and remove debug prints.

Lines 196, 197-198, 202-203 have the same issues as dp_cp_parallel_test_helper: unconditional .amax.clone() crashes for MX formats, config equality checks are brittle, and AWQ-Lite act_scale/pre_quant_scale aren't validated. Additionally, the print statements at lines 197-198 and 202-203 should be removed or converted to logging for production.

Based on past review comments.

Apply the same attribute-guard pattern from dp_cp_parallel_test_helper, replicate for all three groups (dp_group, cp_group, tp_group), and remove print statements:

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                for g in (dp_group, cp_group, tp_group):
+                    dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=g)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"
 
     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)
 
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
 
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)

Comment on lines 620 to 633
    # Hack: MoEs forward all tokens through all experts if _if_calib is True
    module._if_calib = True
    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

    if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
        torch.isnan(module.awq_lite.weight_scale)
    ):
        module.awq_lite.is_enabled = False
        # Hack: MoEs forward all tokens through all experts if _if_calib is True
        module._if_calib = True
    else:
        sync_act_scale_across_dp_cp(
            module,
            module.parallel_state.data_parallel_group,
            module.parallel_state.context_parallel_group,
        )
@coderabbitai bot commented Sep 29, 2025

⚠️ Potential issue | 🟠 Major

Past review concern about NaN propagation remains unresolved.

The previous review identified that NaN detection is performed locally (lines 624-627), but the synchronization call (lines 629-633) will still propagate NaNs from any rank to all ranks in the DP/CP group through the AVG operation. This creates an inconsistent state where:

  1. Rank A detects NaN → sets is_enabled = False locally
  2. Rank B has valid data → enters else block and calls sync
  3. During sync, Rank A's NaN propagates to Rank B via AVG operation
  4. Rank B now has NaN-corrupted act_scale but is_enabled = True

Consider one of these solutions:

  • Add a collective NaN check before sync: use all_reduce to check if ANY rank has NaN, then disable all ranks if true
  • Move the sync before the NaN check and perform NaN detection after synchronization
  • Guard the sync call: only call sync_act_scale_across_dp_cp when module.awq_lite.is_enabled is true (but this requires collective agreement on enabled state)

Based on past review comments.

@jenchen13 (Contributor, Author) commented:

I now only sync the act_scale if no NaNs have been detected (the sync now sits in the else branch).

@coderabbitai bot replied:

@jenchen13 Thank you for the update. While moving the sync to the else block prevents syncing on ranks that detect NaN locally, there's a critical distributed synchronization issue that remains:

The Problem: Collective Operation Mismatch

The dist.all_reduce() call inside sync_act_scale_across_dp_cp is a collective operation that requires all ranks in the DP/CP group to participate. With the current implementation:

  • Rank A (has NaN): Detects NaN → sets is_enabled = False → skips the else block → does NOT call all_reduce
  • Rank B (no NaN): No NaN detected → enters else block → calls all_reduce

This creates a deadlock because Rank B is waiting for Rank A to participate in the collective operation, but Rank A never makes the call.

Suggested Solution

You need to ensure all ranks agree on whether to sync before making the collective call:

# First, check if ANY rank has NaN
has_nan = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)

# Create a global NaN indicator across all ranks
has_nan_tensor = torch.tensor(int(has_nan), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX, 
                    group=module.parallel_state.data_parallel_group.group)
if module.parallel_state.context_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.context_parallel_group.group)

if has_nan_tensor.item() > 0:
    # At least one rank has NaN - disable on ALL ranks
    module.awq_lite.is_enabled = False
else:
    # No rank has NaN - safe to sync on ALL ranks
    sync_act_scale_across_dp_cp(
        module,
        module.parallel_state.data_parallel_group,
        module.parallel_state.context_parallel_group,
    )

This ensures all ranks take the same code path and avoid the collective operation mismatch.

Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)

121-128: Consider removing debug print statements.

The _reduce_quantizer_attr helper is well-structured, but the print statements on lines 123, 125, and 126 will clutter test output. Consider removing them or wrapping them in a debug flag check.

Apply this diff to remove the print statements:

 def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

187-227: Consider attribute-based guards over config equality checks.

The helper correctly validates DP/CP synchronization for amax and act_scale. However, the config-based exclusions (line 197) are fragile—new configs may have different attribute availability. Consider using attribute existence checks instead for more robust validation.

Apply this diff to use attribute-based guards:

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    if hasattr(model.fc1.input_quantizer, "amax") and getattr(model.fc1.input_quantizer, "amax") is not None:
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    if hasattr(model.fc2.input_quantizer, "amax") and getattr(model.fc2.input_quantizer, "amax") is not None:
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

242-251: Consider renaming the internal helper to avoid shadowing.

The internal _reduce_quantizer_attr function (lines 242-251) shadows the module-level helper defined at line 121. While the logic is correct (reducing across all three groups), the naming collision reduces clarity.

Apply this diff to rename the internal helper:

-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+    def _reduce_quantizer_attr_all_groups(quantizer, attr=str, op=dist.ReduceOp.MAX):
         quantizer_attr = getattr(quantizer, attr).clone()
         print("quantizer_attr before reduce", quantizer_attr)
         print("quantizer.attr before reduce", getattr(quantizer, attr))
         dist.all_reduce(quantizer_attr, op=op, group=dp_group)
         dist.all_reduce(quantizer_attr, op=op, group=cp_group)
         dist.all_reduce(quantizer_attr, op=op, group=tp_group)
         print("quantizer_attr after reduce", quantizer_attr)
         print("quantizer.attr after reduce", getattr(quantizer, attr))
         assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

Then update all call sites within the function to use the new name.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d02365c and 95da832.

📒 Files selected for processing (2)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (230-287)
  • dp_cp_parallel_test_helper (188-226)
  • tensor_parallel_test_helper (142-184)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-401)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-677)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (9)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)

98-103: LGTM! Tensor parallel test correctly updated.

The changes properly align with the updated MegatronModel constructor signature and the refactored tensor_parallel_test_helper. Correctly uses the same seed across TP ranks since tensor-parallel layers require identical input across the TP group.


124-146: LGTM! Data parallel test properly validates DP synchronization.

Per-rank seed divergence (SEED + rank) ensures that the test will catch broken DP synchronization. The use of dp_cp_parallel_test_helper with get_data_parallel_group() correctly validates amax and act_scale synchronization across the data-parallel group.


148-174: LGTM! Context parallel test properly validates CP synchronization.

Per-rank seed divergence ensures the test will catch broken CP synchronization. The use of dp_cp_parallel_test_helper with get_context_parallel_group() correctly validates amax and act_scale synchronization across the context-parallel group.


176-206: LGTM! Combined DP+TP+CP test validates multi-dimensional parallelism.

The 2×2×2 configuration correctly tests synchronization across all three parallelism dimensions. Per-rank seed divergence ensures the test will catch broken synchronization in any dimension. The use of data_tensor_context_parallel_test_helper with all three process groups correctly validates amax and act_scale synchronization.

tests/_test_utils/torch_quantization/quantize_common.py (5)

130-139: LGTM! Mock approach ensures consistent test behavior.

Forcing debug=True via the mock ensures that awq_lite.act_scale is preserved for validation, which is essential for testing the synchronization logic.


244-245: Remove debug print statements.

Similar to the module-level helper, the print statements here (lines 244-245, 249-250) will clutter test output. Consider removing them.

Apply this diff:

     def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
         quantizer_attr = getattr(quantizer, attr).clone()
-        print("quantizer_attr before reduce", quantizer_attr)
-        print("quantizer.attr before reduce", getattr(quantizer, attr))
         dist.all_reduce(quantizer_attr, op=op, group=dp_group)
         dist.all_reduce(quantizer_attr, op=op, group=cp_group)
         dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-        print("quantizer_attr after reduce", quantizer_attr)
-        print("quantizer.attr after reduce", getattr(quantizer, attr))
         assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

253-256: Consider attribute-based guards over config equality checks.

Similar to the dp_cp_parallel_test_helper, the config-based exclusions (line 254) are fragile. Consider using attribute existence checks for more robust validation.


141-184: Verify process group teardown usage.

dist.destroy_process_group() is still called in tensor_parallel_test_helper (line 184), which may prematurely tear down the global process group and affect subsequent tests. Confirm whether this explicit teardown is intentional here or if cleanup should be managed by the test harness.


274-287: Test already syncs act_scale across DP, CP, and TP.

The local _reduce_quantizer_attr unconditionally all‐reduces over dp_group, cp_group, and tp_group, so passing group=tp_group doesn’t limit scope—the test already covers synchronization across all three groups. No change needed.

Likely an incorrect or invalid review comment.

@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

182-182: Do not destroy process groups inside helpers.
This tears down the global PG and can break subsequent tests in the same process. Remove it; the harness owns lifecycle.

-    dist.destroy_process_group()
+    # Do not destroy process groups here; test harness owns lifecycle.
🧹 Nitpick comments (5)
tests/gpu/torch/quantization/plugins/test_megatron.py (1)

100-101: Prefer using size for initialization to avoid drift.
Use the provided size consistently.

-    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
+    initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED)
tests/_test_utils/torch_quantization/quantize_common.py (4)

167-181: Only fc1 act_scale validated; fc2 still TODO.
Consider enabling fc2 once the underlying assert is fixed to fully cover AWQ‑Lite.


185-225: DP/CP helper: avoid config checks; guard by attribute presence and cover pre_quant_scale.
Config equality is brittle. Validate whatever attributes exist (amax/pre_quant_scale/act_scale).

-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    # Input quantizer attrs (if present)
+    for iq in (model.fc1.input_quantizer, model.fc2.input_quantizer):
+        if getattr(iq, "amax", None) is not None:
+            _reduce_quantizer_attr(iq, "amax", dist.ReduceOp.MAX, group=group)
+        # SmoothQuant/AWQ‑Lite scales
+        for attr, op in (("pre_quant_scale", dist.ReduceOp.MAX), ("act_scale", dist.ReduceOp.AVG)):
+            if getattr(iq, attr, None) is not None:
+                _reduce_quantizer_attr(iq, attr, op, group=group)
 
-    # Weight quantizer amax
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    # Weight quantizer amax/pre_quant_scale
+    def _reduce_wq(wq):
+        if getattr(wq, "amax", None) is not None:
+            _reduce_quantizer_attr(wq, "amax", dist.ReduceOp.MAX, group=group)
+        if getattr(wq, "pre_quant_scale", None) is not None:
+            _reduce_quantizer_attr(wq, "pre_quant_scale", dist.ReduceOp.MAX, group=group)
+    for wq in (model.fc1.weight_quantizer, model.fc2.weight_quantizer):
+        if isinstance(wq, SequentialQuantizer):
+            for q in wq:
+                _reduce_wq(q)
+        else:
+            _reduce_wq(wq)
-
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Check act scale
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
+    # AWQ‑Lite helper scale (if available)
+    for helper in (getattr(model.fc1, "awq_lite", None), getattr(model.fc2, "awq_lite", None)):
+        if helper is not None and getattr(helper, "act_scale", None) is not None:
+            _reduce_quantizer_attr(helper, "act_scale", dist.ReduceOp.AVG, group=group)

240-250: Avoid name shadowing and fix default arg in nested reducer.
Rename to avoid confusion with the module‑level helper and fix attr default.

-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+    def _reduce_attr_all_groups(quantizer, attr: str, op=dist.ReduceOp.MAX):
         quantizer_attr = getattr(quantizer, attr).clone()
@@
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

And update calls in this helper accordingly.


251-267: Add pre_quant_scale coverage in DP×TP×CP helper.
Mirror TP/DP checks and validate pre_quant_scale for input and weight quantizers.

-    # Input quantizer amax
+    # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _reduce_attr_all_groups(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _reduce_attr_all_groups(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
+    # Input quantizer pre_quant_scale if present
+    for iq in (model.fc1.input_quantizer, model.fc2.input_quantizer):
+        if getattr(iq, "pre_quant_scale", None) is not None:
+            _reduce_attr_all_groups(iq, "pre_quant_scale", dist.ReduceOp.MAX)
@@
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+            _reduce_attr_all_groups(quantizer, "amax", dist.ReduceOp.MAX)
@@
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
+        _reduce_attr_all_groups(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
+    # Weight quantizer pre_quant_scale if present
+    for wq in (model.fc1.weight_quantizer, model.fc2.weight_quantizer):
+        if isinstance(wq, SequentialQuantizer):
+            for q in wq:
+                if getattr(q, "pre_quant_scale", None) is not None:
+                    _reduce_attr_all_groups(q, "pre_quant_scale", dist.ReduceOp.MAX)
+        else:
+            if getattr(wq, "pre_quant_scale", None) is not None:
+                _reduce_attr_all_groups(wq, "pre_quant_scale", dist.ReduceOp.MAX)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 95da832 and 34c11ef.

📒 Files selected for processing (3)
  • examples/nemo_run/qat/README.md (2 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (228-279)
  • dp_cp_parallel_test_helper (186-224)
  • tensor_parallel_test_helper (140-182)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (386-401)
  • MegatronModel (85-131)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-677)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (12)
tests/gpu/torch/quantization/plugins/test_megatron.py (7)

34-37: Good call to reuse shared helpers to cut duplication.
The switch to dp_cp_parallel_test_helper and data_tensor_context_parallel_test_helper keeps assertions centralized.


45-48: Expose CP group via public getter — LGTM.
Importing get_context_parallel_group aligns tests with Megatron public API.


98-104: TP helper wiring looks correct.
Passing tp_size to MegatronModel and using get_tensor_model_parallel_group() is the right scope.


124-146: Per‑rank RNG divergence before DP sync — nice.
initialize_for_megatron(seed=SEED + rank) ensures inputs differ across DP ranks so sync bugs are caught.
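For intuition, a minimal single-process sketch of why this matters (SEED, shapes, and the helper name are illustrative, not the test suite's actual values): seeding with SEED + rank yields distinct calibration batches per rank, so an unsynchronized amax is bound to differ across ranks.

import torch

SEED = 1234  # illustrative value

def make_calib_batch(rank: int, shape=(4, 16)) -> torch.Tensor:
    # Each simulated rank gets its own generator seeded with SEED + rank,
    # so the resulting batches (and their abs().max()) diverge across ranks.
    gen = torch.Generator().manual_seed(SEED + rank)
    return torch.randn(*shape, generator=gen)

amax_per_rank = [make_calib_batch(r).abs().max() for r in range(2)]
assert not torch.isclose(amax_per_rank[0], amax_per_rank[1])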


149-173: CP test mirrors DP path appropriately.
Seeding with SEED + rank and using get_context_parallel_group() verifies CP sync paths.


176-188: 8‑GPU DP×TP×CP test wiring is solid.
Group selection and seeding look correct; the TP all_reduce on calib_data maintains TP invariants.

Also applies to: 190-206


213-230: num_attention_heads=8 change looks consistent.
Hidden size defaults (256) are divisible by heads; this should be safe for test configs.

tests/_test_utils/torch_quantization/quantize_common.py (5)

16-16: Imports for patching and SequentialQuantizer — LGTM.
These enable AWQ‑Lite debug patching and sequential weight quantizer handling.

Also applies to: 26-26, 29-29


131-137: AWQ‑Lite debug patching is correct.
Saving the original and forcing debug=True via side_effect avoids recursion.


139-141: Decorator‑based patching of helpers — LGTM.
The mock is injected via the decorator and accepted by the helper signature.
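For readers less familiar with the pattern, a hedged sketch of how the decorator-based patch fits together; the patch target string and the helper's exact signature are assumptions inferred from this review, not verified against the final test file.

from unittest.mock import patch

import modelopt.torch.quantization.model_calib as model_calib

# Capture the real function before patching so the wrapper delegates to it
# instead of recursing back into the mock.
original_awq_lite = model_calib.awq_lite

def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
    # Force debug=True so per-layer awq_lite helpers (act_scale, etc.) survive
    # calibration and can be asserted on; pass any extra kwargs through.
    return original_awq_lite(model, forward_loop, alpha_step=alpha_step, debug=True, **kwargs)

@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite):
    ...  # quantize + assertions; calls to awq_lite go through the mock into the wrapper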


152-166: Nice addition: validate pre_quant_scale across TP.
Covers SmoothQuant/AWQ‑Lite scale sync across TP.


268-279: AWQ‑Lite act_scale world‑sync check — LGTM.
Using AVG across DP/CP/TP confirms global agreement after per‑group sync.
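Concretely, the no-op property being checked looks roughly like this (a sketch, assuming an initialized NCCL process group, since ReduceOp.AVG is NCCL-only):

import torch
import torch.distributed as dist

def assert_globally_synced(obj, attr: str, group=None):
    # If `attr` was already synchronized during calibration, averaging it again
    # across `group` must leave it unchanged; otherwise the assert fires.
    value = getattr(obj, attr).clone()
    dist.all_reduce(value, op=dist.ReduceOp.AVG, group=group)
    assert torch.allclose(value, getattr(obj, attr))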

Comment on lines 122 to 129
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
    print("quantizer.attr after reduce", getattr(quantizer, attr))
    print("quantizer_attr after reduce", quantizer_attr)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


⚠️ Potential issue | 🟡 Minor

Fix default arg type and reduce noisy prints.

  • attr default should be a string annotation, not the str type.
  • Unconditional prints will spam across ranks. Gate on rank 0 or remove.
-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
    print("quantizer.attr after reduce", getattr(quantizer, attr))
    print("quantizer_attr after reduce", quantizer_attr)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    # Optional: guard debug prints or remove
    # if dist.is_initialized() and dist.get_rank() == 0:
    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
    # if dist.is_initialized() and dist.get_rank() == 0:
    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
    #     print("quantizer_attr after reduce", quantizer_attr)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 122-129,
the function signature incorrectly uses attr=str (making the default the str
type) and unconditionally prints from every rank; change the signature to
annotate attr as a string (e.g. def _reduce_quantizer_attr(quantizer, attr: str,
op=dist.ReduceOp.MAX, group=None):) so attr is typed properly (no stray
default), and remove or gate the print statements behind a single rank (e.g.
only print when dist.is_initialized() and dist.get_rank() == 0) to avoid
spamming across ranks; keep the clone, all_reduce, and assertion logic
unchanged.

Signed-off-by: Jennifer Chen <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

122-128: Fix default arg type and reduce noisy prints.

  • Line 122: attr=str sets the default to the str type itself. Use attr: str (type annotation without default) or provide a proper string default.
  • Lines 124, 126-127: Unconditional prints will spam across all ranks in distributed tests. Gate on rank 0 or remove them.

Apply this diff:

-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
🧹 Nitpick comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)

186-224: Consider attribute-based guards for robustness.

The helper relies on config exclusion lists (line 195) and assumes quantizer attributes exist. If configs evolve or quantizers are conditionally created, this may raise AttributeError.

Prefer guarding by attribute presence:

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+    if getattr(model.fc1.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    if getattr(model.fc2.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

Apply similar guards for act_scale checks:

-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+    if hasattr(model.fc1, "awq_lite") and hasattr(model.fc1.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
             group=group,
         )
+    if hasattr(model.fc2, "awq_lite") and hasattr(model.fc2.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc2.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
             group=group,
         )

228-279: Consider attribute-based guards for combined DP+TP+CP helper.

Similar to dp_cp_parallel_test_helper, this function assumes quantizer attributes exist based on config checks. Prefer attribute presence guards for robustness.

Apply similar changes as suggested for dp_cp_parallel_test_helper:

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+    if getattr(model.fc1.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
+    if getattr(model.fc2.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

And for AWQ-Lite checks:

     # Check act scale
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+    if hasattr(model.fc1, "awq_lite") and hasattr(model.fc1.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
         )
+    if hasattr(model.fc2, "awq_lite") and hasattr(model.fc2.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc2.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 34c11ef and 9f0691f.

📒 Files selected for processing (1)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/_test_utils/torch_quantization/quantize_common.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-677)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)

16-16: LGTM: Imports support AWQ-Lite testing infrastructure.

The new imports are necessary for patching AWQ-Lite calibration and handling sequential quantizers in distributed tests.

Also applies to: 26-26, 28-28


131-136: LGTM: AWQ-Lite mocking preserves debug attributes for validation.

The mock wrapper forces debug=True to retain awq_lite attributes on quantized layers, enabling the test assertions to verify cross-rank synchronization.

Also applies to: 139-139


140-140: LGTM: Added AWQ-Lite act_scale validation for tensor parallelism.

The signature change accommodates the mock decorator, and the new act_scale validation on lines 167-180 directly addresses the PR objective to synchronize AWQ-Lite scales in TP.

Also applies to: 154-180

Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>

    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
        # Check activation scale for AWQ lite
        _reduce_quantizer_attr(
Contributor Author
@jenchen13 jenchen13 Oct 2, 2025

@realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the c_in dimension in activation ... is this right?

@jenchen13 jenchen13 changed the title Sync amax & AWQ-Lite act_scale in context parallel/data parallel Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] Oct 2, 2025
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

168-176: Don’t tear down the global process group in a helper

tensor_parallel_test_helper never creates the default process group, yet it calls dist.destroy_process_group() on exit. This blows away the global NCCL world for the remaining logic in the spawned worker, so any follow-up cleanup (e.g., destroy_model_parallel() or other helpers) will trip over an uninitialized backend. Please drop this call and leave lifecycle management to the test harness.

-    dist.destroy_process_group()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f0691f and 22b8b73.

📒 Files selected for processing (6)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • modelopt/torch/utils/distributed.py (1 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (3 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/utils/distributed.py
  • modelopt/torch/quantization/model_calib.py
  • tests/_test_utils/torch_dist/plugins/megatron_common.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
  • parallel_state (876-878)
  • parallel_state (881-886)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-253)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (225-273)
  • dp_cp_parallel_test_helper (180-221)
  • tensor_parallel_test_helper (140-176)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (390-405)
  • MegatronModel (85-135)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-671)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-135)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines +134 to +137
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
    """Function to mock awq_lite function to always use debug=True for testing"""
    return original_awq_lite(model, forward_loop, alpha_step, debug=True)


⚠️ Potential issue | 🔴 Critical

Forward the AWQ-Lite kwargs in the patch

The _debug_awq_lite wrapper drops every extra keyword argument that callers pass to awq_lite (e.g., tensor_parallel_group, data_parallel_group, max_calib_steps). The upstream API explicitly accepts **kwargs, so the first call that includes one of those options will now raise a TypeError, breaking AWQ-Lite calibration in the very paths this PR is exercising. Please mirror the original signature and forward **kwargs to original_awq_lite while forcing debug=True.

-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
-    """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Force awq_lite debug mode during tests without dropping optional args."""
+    return original_awq_lite(
+        model,
+        forward_loop,
+        alpha_step=alpha_step,
+        debug=True,
+        **kwargs,
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
    """Function to mock awq_lite function to always use debug=True for testing"""
    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
    """Force awq_lite debug mode during tests without dropping optional args."""
    return original_awq_lite(
        model,
        forward_loop,
        alpha_step=alpha_step,
        debug=True,
        **kwargs,
    )
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 134 to
137, the _debug_awq_lite wrapper drops any extra keyword arguments callers pass
to awq_lite which causes TypeError when upstream calls include options like
tensor_parallel_group or max_calib_steps; update the wrapper to mirror the
original awq_lite signature by accepting *args and **kwargs (or the same
explicit params plus **kwargs) and forward them to original_awq_lite while
forcing debug=True (i.e., call original_awq_lite(..., debug=True, **kwargs) so
all upstream options are preserved).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (4)
tests/_test_utils/torch_quantization/quantize_common.py (4)

131-136: Forward kwargs to avoid TypeError in awq_lite calls.

The _debug_awq_lite wrapper drops any extra keyword arguments (e.g., tensor_parallel_group, data_parallel_group), which will cause TypeError when callers pass those options.

Apply this fix:

-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
     """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+    return original_awq_lite(model, forward_loop, alpha_step=alpha_step, debug=True, **kwargs)

Based on past review comments.


224-274: Add fc2.awq_lite.act_scale validation to match DP/CP helper.

The combined DP×TP×CP helper only checks fc1.awq_lite.act_scale (lines 269-273), leaving fc2 unchecked. A regression in the row-parallel branch would go undetected.

Apply this fix:

     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
         )
+        _reduce_quantizer_attr(
+            model.fc2.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+        )

Based on past review comments.


122-128: Remove noisy debug prints and fix type annotation.

Issues:

  1. Line 122: the annotation should read attr: str, not attr=str (which makes the default the str type itself)
  2. Lines 124, 126-127: Unconditional prints spam output across all ranks

Apply this fix:

-def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

Based on past review comments.


139-176: Remove process group destruction from helper.

Line 176 calls dist.destroy_process_group(), which unconditionally tears down the global process group and will break subsequent tests. The test harness manages the lifecycle.

Apply this fix:

     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Check activation scale for AWQ lite
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
             group=tp_group,
         )
-
-    dist.destroy_process_group()

Based on past review comments.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f0691f and ca7c0e8.

📒 Files selected for processing (6)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • modelopt/torch/utils/distributed.py (1 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (3 hunks)
  • tests/_test_utils/torch_quantization/quantize_common.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
  • parallel_state (876-878)
  • parallel_state (881-886)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-253)
modelopt/torch/quantization/model_calib.py (2)
modelopt/torch/utils/distributed.py (2)
  • is_initialized (49-51)
  • is_initialized (196-198)
modelopt/torch/quantization/utils.py (1)
  • is_quantized_linear (246-256)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • SequentialQuantizer (1114-1222)
modelopt/torch/quantization/model_calib.py (1)
  • awq_lite (424-671)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
  • get_dummy_input (130-135)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
  • data_tensor_context_parallel_test_helper (225-273)
  • dp_cp_parallel_test_helper (180-221)
  • tensor_parallel_test_helper (140-176)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • initialize_for_megatron (390-405)
  • MegatronModel (85-135)
tests/gpu/torch/conftest.py (2)
  • need_2_gpus (32-34)
  • need_8_gpus (38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/utils/distributed.py (1)

249-253: LGTM: Improved readability for ParallelState representation.

The multi-line format with trailing commas makes the output easier to read, especially when debugging multi-group parallel configurations.
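Roughly, the change is the shape of the string returned from __repr__; a toy sketch (attribute names only mirror what is shown in this review, the real class lives in modelopt/torch/utils/distributed.py):

class _ParallelStateSketch:
    def __init__(self, data_parallel_group=None, tensor_parallel_group=None):
        self.data_parallel_group = data_parallel_group
        self.tensor_parallel_group = tensor_parallel_group

    def __repr__(self) -> str:
        # One attribute per line with trailing commas, instead of a single
        # long line, which is easier to scan in multi-group debug output.
        return (
            "ParallelState(\n"
            f"  data_parallel_group={self.data_parallel_group},\n"
            f"  tensor_parallel_group={self.tensor_parallel_group},\n"
            ")"
        )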

modelopt/torch/quantization/plugins/megatron.py (2)

18-18: LGTM: Proper logging setup for CP-aware DP group handling.

The logger initialization and import of get_data_parallel_group are necessary for the new fallback logic in _setup.

Also applies to: 26-26, 43-44


224-233: LGTM: Safe fallback for context-parallel group initialization.

The try/except block correctly handles the case where CP is not initialized:

  1. Attempts to get the DP group with CP awareness first
  2. Catches the AssertionError when CP is disabled
  3. Logs a helpful warning
  4. Falls back to the standard DP group

This resolves the past review concern about CP-disabled setups crashing.

Based on past review comments.
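A minimal sketch of the described pattern, assuming Megatron-Core's megatron.core.parallel_state API; the plugin's actual call site, logger, and message wording may differ:

import warnings

from megatron.core import parallel_state

def _get_dp_group_with_cp_fallback():
    try:
        # Preferred: the data-parallel group that folds in context parallelism.
        return parallel_state.get_data_parallel_group(with_context_parallel=True)
    except AssertionError:
        warnings.warn(
            "Context parallel group not initialized; falling back to the plain "
            "data parallel group for quantizer synchronization."
        )
        return parallel_state.get_data_parallel_group()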

tests/_test_utils/torch_dist/plugins/megatron_common.py (3)

86-96: LGTM: MegatronModel now supports context parallelism.

The addition of cp_size parameter and its propagation to TransformerConfig as context_parallel_size correctly enables context-parallel testing.


130-135: LGTM: Seeded dummy input generation enables per-rank divergence.

The optional seed parameter allows tests to generate different calibration data across ranks, which is essential for validating DP/CP synchronization logic.


390-405: LGTM: Context parallelism properly wired into initialization.

The context_parallel_size parameter is correctly:

  1. Added after seed to maintain backward compatibility
  2. Forwarded to initialize_model_parallel
  3. Documented in the function

Based on past review comments.
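Putting the pieces together, a hedged sketch of a CP test body; the import path, SEED constant, and default parallel sizes are assumptions based on the files referenced in this review:

from _test_utils.torch_dist.plugins.megatron_common import (
    MegatronModel,
    initialize_for_megatron,
)

SEED = 1234  # illustrative; the test module defines its own constant

def _context_parallel_worker(rank, size, config):
    # Per-rank seed => divergent calibration data, so a missing CP sync fails loudly.
    initialize_for_megatron(seed=SEED + rank, context_parallel_size=size)
    model = MegatronModel(cp_size=size).cuda()
    calib_data = model.get_dummy_input(seed=SEED + rank).cuda()
    ...  # quantize with `config`, then run the dp_cp_parallel_test_helper assertions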

modelopt/torch/quantization/model_calib.py (3)

83-91: LGTM: Data parallel amax synchronization correctly implemented.

The function properly:

  1. Handles SequentialQuantizer recursively
  2. Syncs amax across the data parallel group when present
  3. Returns early for SequentialQuantizer to avoid double-processing children
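In sketch form, using the attribute spellings quoted later in this thread (parallel_state.data_parallel_group with .is_initialized() and .group); the real helper in model_calib.py may be structured differently:

import torch.distributed as dist

from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer

def _sync_amax_across_dp(quantizer, parallel_state):
    # Recurse into SequentialQuantizer children and return early so each child
    # is reduced exactly once.
    if isinstance(quantizer, SequentialQuantizer):
        for child in quantizer:
            _sync_amax_across_dp(child, parallel_state)
        return
    if getattr(quantizer, "amax", None) is None:
        return
    if parallel_state.data_parallel_group.is_initialized():
        # MAX keeps the largest calibration range observed on any rank.
        dist.all_reduce(
            quantizer.amax,
            op=dist.ReduceOp.MAX,
            group=parallel_state.data_parallel_group.group,
        )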

119-119: Minor: Code marker for TP sync is helpful but non-functional.

The comment serves as documentation for the TP synchronization logic that follows.


602-608: LGTM: Activation scale DP synchronization helper.

Clean helper function that averages act_scale across the data parallel group when initialized.
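Based on the call sites quoted later in this thread, the helper amounts to something like the following sketch (the exact body in model_calib.py may differ):

import torch.distributed as dist

def sync_act_scale_across_dp(module, data_parallel_group):
    # Average AWQ-Lite activation scales over data-parallel replicas so every
    # rank ends up calibrating against the same statistics.
    if data_parallel_group.is_initialized():
        dist.all_reduce(
            module.awq_lite.act_scale,
            op=dist.ReduceOp.AVG,
            group=data_parallel_group.group,
        )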

tests/gpu/torch/quantization/plugins/test_megatron.py (4)

34-36: LGTM: Correct imports for new DP/CP test helpers.

The new imports enable testing of data-parallel and context-parallel synchronization.


97-103: LGTM: Tensor parallel test properly updated.

The test correctly:

  1. Initializes with TP=2
  2. Creates model with tp_size=size
  3. Passes only the TP group to the helper

123-145: LGTM: Data parallel test validates DP synchronization.

The test correctly:

  1. Uses seed=SEED+rank to ensure divergent calibration data across ranks
  2. Creates a non-TP/CP model (DP-only)
  3. Validates synchronization using the DP group

147-173: LGTM: Context parallel test validates CP synchronization.

The test correctly:

  1. Initializes CP with context_parallel_size=size
  2. Uses seed=SEED+rank for per-rank divergence
  3. Creates model with cp_size=size
  4. Validates using the CP-aware DP group
tests/_test_utils/torch_quantization/quantize_common.py (2)

16-16: LGTM: Correct imports for patching and SequentialQuantizer support.

The imports enable mocking awq_lite to force debug mode and handling SequentialQuantizer validation.

Also applies to: 26-29


179-222: LGTM: DP/CP test helper validates synchronization correctly.

The helper properly:

  1. Uses per-rank dummy input
  2. Validates input/weight quantizer amax across the group
  3. Handles SequentialQuantizer
  4. Validates AWQ-Lite act_scale for both fc1 and fc2
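Point 3 usually reduces to iterating the wrapped children; a self-contained sketch of that check (group is the DP/CP group passed to the helper; names here are illustrative):

import torch
import torch.distributed as dist

from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer

def check_weight_amax_synced(weight_quantizer, group=None):
    # W4A8-style configs wrap several quantizers in a SequentialQuantizer;
    # each child has its own amax, so check them one by one.
    children = (
        list(weight_quantizer)
        if isinstance(weight_quantizer, SequentialQuantizer)
        else [weight_quantizer]
    )
    for q in children:
        synced = q.amax.clone()
        dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
        assert torch.allclose(synced, q.amax)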

Comment on lines +615 to +627
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
):
    module.awq_lite.is_enabled = False
    # Hack: MoEs forward all tokens through all experts if _if_calib is True
    module._if_calib = True
else:
    sync_act_scale_across_dp(
        module,
        module.parallel_state.data_parallel_group,
    )

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify: NaN check placement may cause collective operation mismatch.

The current implementation places the sync call in the else block (lines 624-627), which means:

  • Rank A (has NaN): sets is_enabled = False, skips else block, does NOT call sync_act_scale_across_dp
  • Rank B (no NaN): enters else block, CALLS sync_act_scale_across_dp

Since dist.all_reduce (inside sync_act_scale_across_dp at line 606) is a collective operation requiring all ranks to participate, this creates a deadlock when ranks disagree on NaN presence.

Past review comments suggest this should be resolved by either:

  1. Performing a collective NaN check first, then all ranks sync (or skip) together
  2. Moving NaN detection after sync

Consider adding a collective NaN check before the conditional:

# Check if ANY rank has NaN
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX, 
                    group=module.parallel_state.data_parallel_group.group)

if has_nan.item() > 0:
    module.awq_lite.is_enabled = False
else:
    sync_act_scale_across_dp(module, module.parallel_state.data_parallel_group)

Based on past review comments.


🏁 Script executed:

#!/bin/bash
rg -n "sync_act_scale_across_dp" -C3

Length of output: 1217


🏁 Script executed:

#!/bin/bash
sed -n '580,650p' modelopt/torch/quantization/model_calib.py

Length of output: 3165


Collectively sync NaN detection before calling sync_act_scale_across_dp
Perform a global NaN check across all ranks, then either disable calibration everywhere or call the collective sync together. For example, replace the per-rank branch (model_calib.py ~621–627) with:

has_nan_local = (
    torch.any(torch.isnan(module.awq_lite.act_scale))
    or torch.any(torch.isnan(module.awq_lite.weight_scale))
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group)

if has_nan.item() > 0:
    module.awq_lite.is_enabled = False
else:
    sync_act_scale_across_dp(
        module,
        module.parallel_state.data_parallel_group,
    )

This ensures every rank participates in the collective operation and prevents deadlock.

🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 615 to 627, replace
the current per-rank NaN check and conditional call to sync_act_scale_across_dp
with a collective NaN detection: compute a has_nan_local boolean from act_scale
or weight_scale NaNs, create a tensor on the module weight/device with that
value, perform an all_reduce (MAX) across
module.parallel_state.data_parallel_group if it's initialized, then if the
reduced has_nan is >0 set module.awq_lite.is_enabled = False on all ranks,
otherwise call sync_act_scale_across_dp; ensure the tensor is on the correct
device and the collective uses the data_parallel_group to avoid deadlocks.

Signed-off-by: Jennifer Chen <[email protected]>