Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
Conversation
Walkthrough

Adds context-parallel (CP) synchronization for quantization calibration (amax and act_scale), updates the Megatron plugin to explicitly retrieve a data-parallel group with a CP-aware getter (with fallback), changes the ParallelState string representation, expands tests and fixtures to cover DP/TP/CP scenarios (including an 8-GPU fixture), and updates an example README and docker invocation (image tag, Megatron-LM mount, tensor_parallelism flag).
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant Calib as Calibration Loop
participant Module as Quantized Module
participant Q as Tensor/SequentialQuantizer
participant DP as Data-Parallel Group
participant CP as Context-Parallel Group
Calib->>Module: collect stats (forward)
Module->>Q: update amax / act_scale
Note over Module,Q: after per-batch accumulation
alt Data-parallel sync
Module->>DP: all-reduce amax / act_scale
end
alt Context-parallel sync
Module->>CP: all-reduce amax / act_scale
end
Q-->>Module: synchronized scales
Module-->>Calib: continue/finalize calibration
```
```mermaid
sequenceDiagram
autonumber
participant Init as Megatron Plugin Init
participant MPG as megatron.core.parallel_state
participant PS as ParallelState
Init->>MPG: get_data_parallel_group(with_context_parallel=true)
alt context-parallel available
MPG-->>Init: dp_group (CP-aware)
else fallback
Init->>MPG: get_data_parallel_group()
MPG-->>Init: dp_group (fallback)
end
Init->>MPG: get_tensor_model_parallel_group()
Init->>MPG: get_context_parallel_group()
Init->>PS: __init__(data_parallel_group=dp_group, tensor_parallel_group=tp_group, ...)
PS-->>Init: ParallelState ready
```
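In plain `torch.distributed` terms, the synchronization the diagrams describe amounts to one all-reduce per parallel group. The following is a minimal sketch under the assumption that the quantizer exposes `amax` and (for AWQ-Lite) `act_scale` tensors; the helper itself is illustrative and is not the repository's implementation.

```python
import torch.distributed as dist


def sync_calibration_stats(quantizer, dp_group=None, cp_group=None):
    """Sketch: align per-rank calibration stats across DP and CP replicas."""
    for group in (dp_group, cp_group):
        if group is None:
            continue
        # amax is a running maximum, so MAX is the natural reduction.
        if getattr(quantizer, "amax", None) is not None:
            dist.all_reduce(quantizer.amax, op=dist.ReduceOp.MAX, group=group)
        # AWQ-Lite's act_scale is an average statistic, so AVG keeps it unbiased,
        # assuming every rank saw a similar number of tokens (AVG requires NCCL).
        if getattr(quantizer, "act_scale", None) is not None:
            dist.all_reduce(quantizer.act_scale, op=dist.ReduceOp.AVG, group=group)
```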
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ 3 checks passed
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #359      +/-   ##
==========================================
- Coverage   73.79%   73.78%   -0.01%
==========================================
  Files         171      171
  Lines       17591    17596       +5
==========================================
+ Hits        12982    12984       +2
- Misses       4609     4612       +3
```
@jenchen13 could you please add unit tests for context-parallel quantization (similar to the tensor-parallel ones) here? Basically, the TP test checks whether amax is similar across the TP group; see TensorRT-Model-Optimizer/tests/_test_utils/torch_quantization/quantize_common.py, line 119 at 26c203a.
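For illustration only, a context-parallel analogue of that check might look like the sketch below; the helper name and the way the group is passed are assumptions, not the repository's actual test code.

```python
import torch
import torch.distributed as dist


def check_amax_synced_across_cp(model, cp_group):
    """Sketch: after calibration, amax should already agree across the CP group."""
    local_amax = model.fc1.input_quantizer.amax.clone()
    reduced_amax = local_amax.clone()
    dist.all_reduce(reduced_amax, op=dist.ReduceOp.MAX, group=cp_group)
    # If CP synchronization worked during calibration, the local value is
    # already the group-wide maximum, so the all-reduce changes nothing.
    assert torch.allclose(local_amax, reduced_amax)
```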
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

132-149: Replace config-based guards with attribute-presence checks and sync AWQ pre_quant_scale

• For amax:

```python
if model.fc2.input_quantizer.amax is not None:
    activation_amax = model.fc2.input_quantizer.amax.clone()
    dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
```

• For scales (SmoothQuant, AWQ/AWQ-Lite):

```python
# input scale
if (scale := model.fc1.input_quantizer.pre_quant_scale) is not None:
    scale_clone = scale.clone()
    dist.all_reduce(scale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(scale_clone, scale)

# weight scale (AWQ-Lite)
if (wscale := model.fc1.weight_quantizer.pre_quant_scale) is not None:
    wscale_clone = wscale.clone()
    dist.all_reduce(wscale_clone, op=dist.ReduceOp.MAX, group=tp_group)
    assert torch.allclose(wscale_clone, wscale)
```

Drop all `if config in […]` checks.
🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)

628-651: Weight DP/CP averages by token counts

Right now we average `act_scale` equally across ranks. In mixed-workload runs (e.g., MoE routing) we can see uneven `num_tokens`, so the lighter ranks end up pulling the mean down. Since we already track `num_tokens`, we can switch to a weighted reduction (sum of scale × tokens and sum of tokens) before normalizing. A sketch:

```diff
-    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-    sync_act_scale_across_dp_cp(...)
+    module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+    token_count = torch.tensor(
+        [module.awq_lite.num_tokens],
+        device=module.awq_lite.act_scale.device,
+        dtype=module.awq_lite.act_scale.dtype,
+    )
+    scale_sum = module.awq_lite.act_scale * token_count.item()
+    sync_reduction_across_dp_cp(scale_sum, token_count, module.parallel_state)
+    module.awq_lite.act_scale = scale_sum / token_count
```

(`sync_reduction_across_dp_cp` would all-reduce both tensors across DP/CP groups.)

tests/_test_utils/torch_quantization/quantize_common.py (1)

215-237: 3D helper: sync input pre_quant_scale across TP/CP/DP
- Replace the config-based presence check with an attribute check for input amax (e.g. if getattr(model.fc1.input_quantizer, "amax", None) is not None) in tests/_test_utils/torch_quantization/quantize_common.py → data_tensor_context_parallel_test_helper.
- If input_quantizer.pre_quant_scale is present, clone it, all_reduce (MAX) across tp_group, cp_group, dp_group and assert torch.allclose(reduced, input_quantizer.pre_quant_scale).
- AWQ-Lite’s activation scale is stored on module.awq_lite.act_scale and model_calib syncs it using AVG across DP/CP — if you want to validate AWQ-Lite end-to-end, also check module.awq_lite.act_scale is synchronized (use the same group ops as model_calib).
📒 Files selected for processing (8): examples/nemo_run/qat/README.md (1 hunk), modelopt/torch/quantization/model_calib.py (2 hunks), modelopt/torch/quantization/plugins/megatron.py (2 hunks), modelopt/torch/utils/distributed.py (1 hunk), tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks), tests/_test_utils/torch_quantization/quantize_common.py (2 hunks), tests/gpu/torch/conftest.py (1 hunk), tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🔇 Additional comments (7)
modelopt/torch/quantization/model_calib.py (1)

82-98: Nice DP/CP amax sync integration
The recursive helper cleanly reuses the existing SequentialQuantizer handling while adding the CP hop, so both DP and CP replicas end up aligned. 👍

tests/gpu/torch/conftest.py (1)

37-40: `need_8_gpus` fixture looks good
The skip guard mirrors `need_2_gpus`, so multi-rank tests will short-circuit cleanly when the hardware isn't there.

tests/gpu/torch/quantization/plugins/test_megatron.py (1)

176-205: Great coverage for the 2×2×2 scenario
Spinning up the combined DP/TP/CP path ensures the new sync logic is exercised end-to-end; thanks for wiring the groups explicitly.

tests/_test_utils/torch_quantization/quantize_common.py (2)

129-131: Good addition: post-quantization sanity forward.
Running a forward pass after quantize helps catch latent issues. Looks good.

205-212: LGTM on the sequential TP→CP→DP all-reduce pattern.
This correctly propagates maxima across orthogonal groups in 3D parallelism.
Please confirm group construction matches orthogonal decomposition (i.e., each rank belongs to exactly one group per dimension). If not, propagation may be incomplete.

examples/nemo_run/qat/README.md (1)

95-95: Doc command update looks good.
Nice to see the example reflecting the new tensor-parallel setup.

modelopt/torch/utils/distributed.py (1)

244-256: Context group wiring looks solid.
`ParallelState` now mirrors DP/TP/CP consistently, so downstream logging/debugging will show the full layout.
```python
data_parallel_group = None
try:
    data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
    getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
)
```
Guard `get_context_parallel_group()` when CP is disabled

`get_context_parallel_group()` asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and `_MegatronParallelLinear._setup()` will crash. Please mirror the DP guard and fall back to `-1` (unused) when the call raises.

Something along these lines keeps the DP-only path working:
```diff
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested replacement:

```python
data_parallel_group = None
try:
    data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = get_data_parallel_group()
try:
    context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
    context_parallel_group = -1
self.parallel_state = ParallelState(
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    context_parallel_group,
)
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.
Actionable comments posted: 2
📒 Files selected for processing (2): tests/_test_utils/torch_quantization/quantize_common.py (3 hunks), tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🔇 Additional comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)

151-151: Remove process group teardown from helper
We're still tearing down the global process group inside the helper, which the test harness owns. Subsequent jobs in the same worker will hit a destroyed NCCL group. Please drop this call.

```diff
-    dist.destroy_process_group()
```
162-182: Guard quantizer sync for AWQ-Lite and validate act_scale
AWQ-Lite quantizers set `is_mx_format`, so `TensorQuantizer.amax` returns `None` (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py, lines 230-235). The current helper unconditionally clones `.amax`, which will raise `AttributeError` as soon as we run the AWQ-Lite DP/CP test. Even when it doesn't explode, we never exercise the newly synchronized `act_scale`, so regressions slide through. Please guard for missing `amax` and also check `pre_quant_scale`/`act_scale` when present.

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        validated = False
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is None:
+                continue
+            synced = tensor.detach().clone()
+            dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+            assert torch.allclose(synced, tensor), f"{attr} mismatch on {quantizer}"
+            validated = True
+        if not validated:
+            pytest.fail(f"{quantizer} exposed no distributed attribute to validate")
@@
-    reduce_amax(model.fc1.input_quantizer)
-    reduce_amax(model.fc2.input_quantizer)
+    reduce_quant_attrs(model.fc1.input_quantizer)
+    reduce_quant_attrs(model.fc2.input_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
@@
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
@@
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

602-612: Consider adding a docstring for consistency.
The act_scale synchronization logic is correct (using AVG for statistics is appropriate), but the function lacks a docstring.

Apply this diff to add a docstring:

```diff
 def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+    """Synchronize the act_scale across all ranks in the data parallel and context parallel groups."""
     # Sync across Data Parallel (DP)
```
📒 Files selected for processing (2): modelopt/torch/quantization/model_calib.py (2 hunks), modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🚧 Skipped from review (similar to previous changes): modelopt/torch/quantization/plugins/megatron.py
🔇 Additional comments (2)
modelopt/torch/quantization/model_calib.py (2)

83-91: LGTM! CP support properly added to amax synchronization.
The extension from DP-only to DP+CP synchronization is implemented correctly. The function handles both `SequentialQuantizer` recursion and direct `TensorQuantizer` sync, and the docstring accurately reflects the new behavior.

98-98: Correct update to use CP-aware synchronization.
The call site was properly updated to invoke the new CP-aware sync function.
Actionable comments posted: 4
📒 Files selected for processing (5): modelopt/torch/quantization/model_calib.py (2 hunks), tests/_test_utils/torch_dist/plugins/megatron_common.py (2 hunks), tests/_test_utils/torch_quantization/quantize_common.py (3 hunks), tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks), tests/gpu/torch/quantization/test_model_calib.py (1 hunk)
🚧 Skipped from review (similar to previous changes): tests/_test_utils/torch_dist/plugins/megatron_common.py
🔇 Additional comments (13)
modelopt/torch/quantization/model_calib.py (2)

83-98: LGTM! Well-structured CP-aware amax synchronization.
The implementation correctly extends amax synchronization to context parallel groups while maintaining backward compatibility with data parallel. The recursive handling of `SequentialQuantizer` is appropriate, and using `ReduceOp.MAX` for both DP and CP groups ensures correct semantics (maximum amax across all ranks).

602-613: LGTM! Correct use of averaging for act_scale synchronization.
The function properly synchronizes `act_scale` using `ReduceOp.AVG` across both DP and CP groups, which is the correct reduction operation for averaging activation scales. The guard checks ensure synchronization only occurs when the groups are initialized.

tests/gpu/torch/quantization/test_model_calib.py (1)

32-33: LGTM! Test setup is correct.
The test properly uses `spawn_multiprocess_job` with 2 GPUs and the NCCL backend for distributed testing.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)
34-35: LGTM!
The new imports are necessary for the DP/CP test helpers and context-parallel group retrieval used in the tests below.
Also applies to: 45-45

101-103: LGTM!
The explicit `tp_size` keyword argument improves clarity, and removing `dp_group` from the `tensor_parallel_test_helper` call aligns with the updated signature in quantize_common.py.

124-130: Per-rank seed is overridden; test won't catch broken DP sync.
Passing `SEED + rank` to `initialize_for_megatron` is overridden by the internal call to `model_parallel_cuda_manual_seed(seed)` (see tests/_test_utils/torch_dist/plugins/megatron_common.py, lines 385-400), so all ranks still produce identical `get_dummy_input()` activations. The test will pass even if DP synchronization is broken. Introduce a rank-dependent perturbation after initialization, e.g., reseed or add a small offset before calling `dp_cp_parallel_test_helper`.
Based on past review comments.

148-156: Per-rank seed is overridden; test won't catch broken CP sync.
The same issue from the DP test applies here: `initialize_for_megatron` internally calls `model_parallel_cuda_manual_seed(seed)` with the provided seed, overriding the per-rank divergence you intended with `SEED + rank`. All CP ranks will produce identical calibration data, so the test won't fail if CP synchronization regresses. Add a rank-dependent perturbation after initialization.
Based on past review comments.

176-187: Fixed seed produces identical calibration data; test won't catch broken DP/CP sync.
Line 178 uses `SEED` without rank-dependent divergence. Since `initialize_for_megatron` calls `model_parallel_cuda_manual_seed(SEED)` uniformly across all 8 ranks, every rank will produce identical `get_dummy_input()` activations, so the assertions in `data_tensor_context_parallel_test_helper` will pass even if DP or CP synchronization is broken. Introduce rank-dependent perturbation (e.g., `SEED + rank + 1`) or add a small offset after initialization to ensure different calibration data per DP/CP rank.
Based on past review comments.
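As a concrete illustration of the rank-dependent perturbation these comments ask for, a sketch is shown below; the function name, shapes, and seed offset are assumptions rather than the repository's helpers.

```python
import torch
import torch.distributed as dist


def make_rank_divergent_calib_data(batch=2, seq_len=4, hidden_size=256, device="cuda"):
    """Sketch: give each DP/CP rank different calibration activations."""
    rank = dist.get_rank()
    # Re-seed after Megatron's uniform model_parallel_cuda_manual_seed(...) call
    # so the RNG stream differs per rank.
    torch.manual_seed(1234 + rank)
    data = torch.randn(batch, seq_len, hidden_size, device=device)
    # Scale by (1 + rank) so replicas diverge even with identical RNG streams,
    # making a broken amax/act_scale sync show up as a cross-rank mismatch.
    return data * (1.0 + rank)
```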
tests/_test_utils/torch_quantization/quantize_common.py (5)

26-26: LGTM!
The `SequentialQuantizer` import is necessary for the new helpers to handle multi-format weight quantization correctly.

120-120: LGTM!
Removing the unused `dp_group` parameter simplifies the signature; the function only validates tensor-parallel synchronization.

151-151: Remove global process group destruction from helper.
This call unconditionally tears down the global process group, breaking any subsequent DP/CP/TP tests that run in the same process. The test harness owns the process group lifecycle. Remove this line.
Based on past review comments.
Apply this diff:

```diff
-    dist.destroy_process_group()
```

154-183: Guard amax access and add AWQ-Lite scale validation.
Line 163 unconditionally accesses `quantizer.amax`, which returns `None` for MX formats (see modelopt/torch/quantization/nn/modules/tensor_quantizer.py, line 237) and will crash on `.clone()`. The config equality check at line 168 is brittle and misses future configs. Additionally, the PR objective includes synchronizing AWQ-Lite `act_scale`, but this helper doesn't validate it.
Based on past review comments.
Replace config checks with attribute guards and add scale validation:

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)

     # Weight quantizer amax
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
185-222: Guard amax access, add AWQ-Lite scale validation, and remove debug prints.
Lines 196, 197-198, 202-203 have the same issues as `dp_cp_parallel_test_helper`: unconditional `.amax.clone()` crashes for MX formats, config equality checks are brittle, and AWQ-Lite `act_scale`/`pre_quant_scale` aren't validated. Additionally, the debug print statements should be removed.
Based on past review comments.
Apply the same attribute-guard pattern from `dp_cp_parallel_test_helper`, replicate it for all three groups (dp_group, cp_group, tp_group), and remove the print statements:

```diff
-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def reduce_quant_attrs(quantizer):
+        for attr in ("amax", "pre_quant_scale", "act_scale"):
+            tensor = getattr(quantizer, attr, None)
+            if tensor is not None:
+                synced = tensor.clone()
+                for g in (dp_group, cp_group, tp_group):
+                    dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=g)
+                assert torch.allclose(synced, tensor), f"{attr} mismatch"

     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+    if hasattr(model.fc1, "input_quantizer"):
+        reduce_quant_attrs(model.fc1.input_quantizer)
+    if hasattr(model.fc2, "input_quantizer"):
+        reduce_quant_attrs(model.fc2.input_quantizer)

     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        reduce_quant_attrs(model.fc1.weight_quantizer)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            reduce_quant_attrs(quantizer)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        reduce_quant_attrs(model.fc2.weight_quantizer)
```
```python
                # Hack: MoEs forward all tokens through all experts if _if_calib is True
                module._if_calib = True

        module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

        if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
            torch.isnan(module.awq_lite.weight_scale)
        ):
            module.awq_lite.is_enabled = False
            # Hack: MoEs forward all tokens through all experts if _if_calib is True
            module._if_calib = True
        else:
            sync_act_scale_across_dp_cp(
                module,
                module.parallel_state.data_parallel_group,
                module.parallel_state.context_parallel_group,
            )
```
Past review concern about NaN propagation remains unresolved.
The previous review identified that NaN detection is performed locally (lines 624-627), but the synchronization call (lines 629-633) will still propagate NaNs from any rank to all ranks in the DP/CP group through the `AVG` operation. This creates an inconsistent state where:

- Rank A detects NaN → sets `is_enabled = False` locally
- Rank B has valid data → enters the else block and calls sync
- During sync, Rank A's NaN propagates to Rank B via the `AVG` operation
- Rank B now has NaN-corrupted `act_scale` but `is_enabled = True`

Consider one of these solutions:

- Add a collective NaN check before sync: use `all_reduce` to check if ANY rank has NaN, then disable all ranks if true
- Move the sync before the NaN check and perform NaN detection after synchronization
- Guard the sync call: only call `sync_act_scale_across_dp_cp` when `module.awq_lite.is_enabled` is true (but this requires collective agreement on the enabled state)
I now only sync the `act_scale` if no NaNs have been detected (it is after an `else` statement).
@jenchen13 Thank you for the update. While moving the sync to the `else` block prevents syncing on ranks that detect NaN locally, there's a critical distributed synchronization issue that remains:

The Problem: Collective Operation Mismatch

The `dist.all_reduce()` call inside `sync_act_scale_across_dp_cp` is a collective operation that requires all ranks in the DP/CP group to participate. With the current implementation:

- Rank A (has NaN): detects NaN → sets `is_enabled = False` → skips the else block → does NOT call `all_reduce`
- Rank B (no NaN): no NaN detected → enters the `else` block → calls `all_reduce`

This creates a deadlock because Rank B is waiting for Rank A to participate in the collective operation, but Rank A never makes the call.
Suggested Solution
You need to ensure all ranks agree on whether to sync before making the collective call:
```python
# First, check if ANY rank has NaN
has_nan = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)

# Create a global NaN indicator across all ranks
has_nan_tensor = torch.tensor(int(has_nan), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group)
if module.parallel_state.context_parallel_group.is_initialized():
    dist.all_reduce(has_nan_tensor, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.context_parallel_group.group)

if has_nan_tensor.item() > 0:
    # At least one rank has NaN - disable on ALL ranks
    module.awq_lite.is_enabled = False
else:
    # No rank has NaN - safe to sync on ALL ranks
    sync_act_scale_across_dp_cp(
        module,
        module.parallel_state.data_parallel_group,
        module.parallel_state.context_parallel_group,
    )
```
This ensures all ranks take the same code path and avoid the collective operation mismatch.
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)

121-128: Consider removing debug print statements.
The `_reduce_quantizer_attr` helper is well-structured, but the print statements on lines 123, 125, and 126 will clutter test output. Consider removing them or wrapping them in a debug flag check.

Apply this diff to remove the print statements:

```diff
 def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

187-227: Consider attribute-based guards over config equality checks.
The helper correctly validates DP/CP synchronization for amax and act_scale. However, the config-based exclusions (line 197) are fragile; new configs may have different attribute availability. Consider using attribute existence checks instead for more robust validation.

Apply this diff to use attribute-based guards:

```diff
     # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    if getattr(model.fc1.input_quantizer, "amax", None) is not None:
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+    if getattr(model.fc2.input_quantizer, "amax", None) is not None:
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
```

242-251: Consider renaming the internal helper to avoid shadowing.
The internal `_reduce_quantizer_attr` function (lines 242-251) shadows the module-level helper defined at line 121. While the logic is correct (reducing across all three groups), the naming collision reduces clarity.

Apply this diff to rename the internal helper:

```diff
-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+    def _reduce_quantizer_attr_all_groups(quantizer, attr=str, op=dist.ReduceOp.MAX):
         quantizer_attr = getattr(quantizer, attr).clone()
         print("quantizer_attr before reduce", quantizer_attr)
         print("quantizer.attr before reduce", getattr(quantizer, attr))
         dist.all_reduce(quantizer_attr, op=op, group=dp_group)
         dist.all_reduce(quantizer_attr, op=op, group=cp_group)
         dist.all_reduce(quantizer_attr, op=op, group=tp_group)
         print("quantizer_attr after reduce", quantizer_attr)
         print("quantizer.attr after reduce", getattr(quantizer, attr))
         assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

Then update all call sites within the function to use the new name.

📒 Files selected for processing (2): tests/_test_utils/torch_quantization/quantize_common.py (3 hunks), tests/gpu/torch/quantization/plugins/test_megatron.py (4 hunks)
🔇 Additional comments (9)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)

98-103: LGTM! Tensor parallel test correctly updated.
The changes properly align with the updated `MegatronModel` constructor signature and the refactored `tensor_parallel_test_helper`. Correctly uses the same seed across TP ranks, since tensor-parallel layers require identical input across the TP group.

124-146: LGTM! Data parallel test properly validates DP synchronization.
Per-rank seed divergence (SEED + rank) ensures that the test will catch broken DP synchronization. The use of `dp_cp_parallel_test_helper` with `get_data_parallel_group()` correctly validates amax and act_scale synchronization across the data-parallel group.

148-174: LGTM! Context parallel test properly validates CP synchronization.
Per-rank seed divergence ensures the test will catch broken CP synchronization. The use of `dp_cp_parallel_test_helper` with `get_context_parallel_group()` correctly validates amax and act_scale synchronization across the context-parallel group.

176-206: LGTM! Combined DP+TP+CP test validates multi-dimensional parallelism.
The 2×2×2 configuration correctly tests synchronization across all three parallelism dimensions. Per-rank seed divergence ensures the test will catch broken synchronization in any dimension. The use of `data_tensor_context_parallel_test_helper` with all three process groups correctly validates amax and act_scale synchronization.

tests/_test_utils/torch_quantization/quantize_common.py (5)

130-139: LGTM! Mock approach ensures consistent test behavior.
Forcing `debug=True` via the mock ensures that `awq_lite.act_scale` is preserved for validation, which is essential for testing the synchronization logic.

244-245: Remove debug print statements.
Similar to the module-level helper, the print statements here (lines 244-245, 249-250) will clutter test output. Consider removing them.

Apply this diff:

```diff
 def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer_attr before reduce", quantizer_attr)
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=dp_group)
     dist.all_reduce(quantizer_attr, op=op, group=cp_group)
     dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-    print("quantizer_attr after reduce", quantizer_attr)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

253-256: Consider attribute-based guards over config equality checks.
Similar to `dp_cp_parallel_test_helper`, the config-based exclusions (line 254) are fragile. Consider using attribute existence checks for more robust validation.

141-184: Verify process group teardown usage.
`dist.destroy_process_group()` is still called in `tensor_parallel_test_helper` (line 184), which may prematurely tear down the global process group and affect subsequent tests. Confirm whether this explicit teardown is intentional here or if cleanup should be managed by the test harness.

274-287: Test already syncs act_scale across DP, CP, and TP.
The local `_reduce_quantizer_attr` unconditionally all-reduces over `dp_group`, `cp_group`, and `tp_group`, so passing `group=tp_group` doesn't limit scope; the test already covers synchronization across all three groups. No change needed.
Likely an incorrect or invalid review comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

182-182: Do not destroy process groups inside helpers.
This tears down the global PG and can break subsequent tests in the same process. Remove it; the harness owns the lifecycle.

```diff
-    dist.destroy_process_group()
+    # Do not destroy process groups here; test harness owns lifecycle.
```
🧹 Nitpick comments (5)
tests/gpu/torch/quantization/plugins/test_megatron.py (1)

100-101: Prefer using size for initialization to avoid drift.
Use the provided size consistently.

```diff
-    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
+    initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED)
```

tests/_test_utils/torch_quantization/quantize_common.py (4)

167-181: Only fc1 act_scale validated; fc2 still TODO.
Consider enabling fc2 once the underlying assert is fixed to fully cover AWQ-Lite.
185-225
: DP/CP helper: avoid config checks; guard by attribute presence and cover pre_quant_scale.
Config equality is brittle. Validate whatever attributes exist (amax/pre_quant_scale/act_scale).- # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + # Input quantizer attrs (if present) + for iq in (model.fc1.input_quantizer, model.fc2.input_quantizer): + if getattr(iq, "amax", None) is not None: + _reduce_quantizer_attr(iq, "amax", dist.ReduceOp.MAX, group=group) + # SmoothQuant/AWQ‑Lite scales + for attr, op in (("pre_quant_scale", dist.ReduceOp.MAX), ("act_scale", dist.ReduceOp.AVG)): + if getattr(iq, attr, None) is not None: + _reduce_quantizer_attr(iq, attr, op, group=group) - # Weight quantizer amax - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + # Weight quantizer amax/pre_quant_scale + def _reduce_wq(wq): + if getattr(wq, "amax", None) is not None: + _reduce_quantizer_attr(wq, "amax", dist.ReduceOp.MAX, group=group) + if getattr(wq, "pre_quant_scale", None) is not None: + _reduce_quantizer_attr(wq, "pre_quant_scale", dist.ReduceOp.MAX, group=group) + for wq in (model.fc1.weight_quantizer, model.fc2.weight_quantizer): + if isinstance(wq, SequentialQuantizer): + for q in wq: + _reduce_wq(q) + else: + _reduce_wq(wq) - - if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: - # Check act scale - _reduce_quantizer_attr( - model.fc1.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=group, - ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=group, - ) + # AWQ‑Lite helper scale (if available) + for helper in (getattr(model.fc1, "awq_lite", None), getattr(model.fc2, "awq_lite", None)): + if helper is not None and getattr(helper, "act_scale", None) is not None: + _reduce_quantizer_attr(helper, "act_scale", dist.ReduceOp.AVG, group=group)
240-250
: Avoid name shadowing and fix default arg in nested reducer.
Rename to avoid confusion with the module‑level helper and fix attr default.- def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + def _reduce_attr_all_groups(quantizer, attr: str, op=dist.ReduceOp.MAX): quantizer_attr = getattr(quantizer, attr).clone() @@ - assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr))And update calls in this helper accordingly.
251-267
: Add pre_quant_scale coverage in DP×TP×CP helper.
Mirror TP/DP checks and validate pre_quant_scale for input and weight quantizers.- # Input quantizer amax + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_attr_all_groups(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_attr_all_groups(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) + # Input quantizer pre_quant_scale if present + for iq in (model.fc1.input_quantizer, model.fc2.input_quantizer): + if getattr(iq, "pre_quant_scale", None) is not None: + _reduce_attr_all_groups(iq, "pre_quant_scale", dist.ReduceOp.MAX) @@ - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + _reduce_attr_all_groups(quantizer, "amax", dist.ReduceOp.MAX) @@ - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_attr_all_groups(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + # Weight quantizer pre_quant_scale if present + for wq in (model.fc1.weight_quantizer, model.fc2.weight_quantizer): + if isinstance(wq, SequentialQuantizer): + for q in wq: + if getattr(q, "pre_quant_scale", None) is not None: + _reduce_attr_all_groups(q, "pre_quant_scale", dist.ReduceOp.MAX) + else: + if getattr(wq, "pre_quant_scale", None) is not None: + _reduce_attr_all_groups(wq, "pre_quant_scale", dist.ReduceOp.MAX)
📒 Files selected for processing (3): examples/nemo_run/qat/README.md (2 hunks), tests/_test_utils/torch_quantization/quantize_common.py (3 hunks), tests/gpu/torch/quantization/plugins/test_megatron.py (6 hunks)
🚧 Skipped from review (similar to previous changes): examples/nemo_run/qat/README.md
🔇 Additional comments (12)
tests/gpu/torch/quantization/plugins/test_megatron.py (7)
34-37
: Good call to reuse shared helpers to cut duplication.
The switch to dp_cp_parallel_test_helper and data_tensor_context_parallel_test_helper keeps assertions centralized.
45-48
: Expose CP group via public getter — LGTM.
Importing get_context_parallel_group aligns tests with Megatron public API.
98-104
: TP helper wiring looks correct.
Passing tp_size to MegatronModel and using get_tensor_model_parallel_group() is the right scope.
124-146
: Per‑rank RNG divergence before DP sync — nice.
initialize_for_megatron(seed=SEED + rank) ensures inputs differ across DP ranks so sync bugs are caught.
149-173
: CP test mirrors DP path appropriately.
Seeding with SEED + rank and using get_context_parallel_group() verifies CP sync paths.
176-188
: 8‑GPU DP×TP×CP test wiring is solid.
Group selection and seeding look correct; the TP all_reduce on calib_data maintains TP invariants.Also applies to: 190-206
213-230
: num_attention_heads=8 change looks consistent.
Hidden size defaults (256) are divisible by heads; this should be safe for test configs.tests/_test_utils/torch_quantization/quantize_common.py (5)
16-16
: Imports for patching and SequentialQuantizer — LGTM.
These enable AWQ‑Lite debug patching and sequential weight quantizer handling.Also applies to: 26-26, 29-29
131-137
: AWQ‑Lite debug patching is correct.
Saving the original and forcing debug=True via side_effect avoids recursion.
139-141
: Decorator‑based patching of helpers — LGTM.
The mock is injected via the decorator and accepted by the helper signature.
152-166
: Nice addition: validate pre_quant_scale across TP.
Covers SmoothQuant/AWQ‑Lite scale sync across TP.
268-279
: AWQ‑Lite act_scale world‑sync check — LGTM.
Using AVG across DP/CP/TP confirms global agreement after per‑group sync.
```python
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
    print("quantizer.attr after reduce", getattr(quantizer, attr))
    print("quantizer_attr after reduce", quantizer_attr)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
Fix default arg type and reduce noisy prints.
- attr default should be a string annotation, not the str type.
- Unconditional prints will spam across ranks. Gate on rank 0 or remove.
```diff
-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested replacement:

```python
def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    # Optional: guard debug prints or remove
    # if dist.is_initialized() and dist.get_rank() == 0:
    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
    # if dist.is_initialized() and dist.get_rank() == 0:
    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
    #     print("quantizer_attr after reduce", quantizer_attr)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 122-129,
the function signature incorrectly uses attr=str (making the default the str
type) and unconditionally prints from every rank; change the signature to
annotate attr as a string (e.g. def _reduce_quantizer_attr(quantizer, attr: str,
op=dist.ReduceOp.MAX, group=None):) so attr is typed properly (no stray
default), and remove or gate the print statements behind a single rank (e.g.
only print when dist.is_initialized() and dist.get_rank() == 0) to avoid
spamming across ranks; keep the clone, all_reduce, and assertion logic
unchanged.
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)

122-128: Fix default arg type and reduce noisy prints.

- Line 122: `attr=str` sets the default to the `str` type itself. Use `attr: str` (a type annotation without a default) or provide a proper string default.
- Lines 124, 126-127: Unconditional prints will spam across all ranks in distributed tests. Gate on rank 0 or remove them.

Apply this diff:

```diff
-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
+    # Optional: guard debug prints or remove
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+    # if dist.is_initialized() and dist.get_rank() == 0:
+    #     print("quantizer.attr after reduce", getattr(quantizer, attr))
+    #     print("quantizer_attr after reduce", quantizer_attr)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```
🧹 Nitpick comments (2)
tests/_test_utils/torch_quantization/quantize_common.py (2)
186-224
: Consider attribute-based guards for robustness.The helper relies on config exclusion lists (line 195) and assumes quantizer attributes exist. If configs evolve or quantizers are conditionally created, this may raise
AttributeError
.Prefer guarding by attribute presence:
# Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + if getattr(model.fc1.input_quantizer, "amax", None) is not None: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + if getattr(model.fc2.input_quantizer, "amax", None) is not None: _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)Apply similar guards for
act_scale
checks:- if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + if hasattr(model.fc1, "awq_lite") and hasattr(model.fc1.awq_lite, "act_scale"): _reduce_quantizer_attr( model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) + if hasattr(model.fc2, "awq_lite") and hasattr(model.fc2.awq_lite, "act_scale"): _reduce_quantizer_attr( model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, )
228-279: Consider attribute-based guards for the combined DP+TP+CP helper.
Similar to dp_cp_parallel_test_helper, this function assumes quantizer attributes exist based on config checks. Prefer attribute presence guards for robustness. Apply similar changes as suggested for dp_cp_parallel_test_helper:
 # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+    if getattr(model.fc1.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
+    if getattr(model.fc2.input_quantizer, "amax", None) is not None:
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
And for the AWQ-Lite checks:
 # Check act scale
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+    if hasattr(model.fc1, "awq_lite") and hasattr(model.fc1.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
         )
+    if hasattr(model.fc2, "awq_lite") and hasattr(model.fc2.awq_lite, "act_scale"):
         _reduce_quantizer_attr(
             model.fc2.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
         )
🔇 Additional comments (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
16-16: LGTM: Imports support AWQ-Lite testing infrastructure. The new imports are necessary for patching AWQ-Lite calibration and handling sequential quantizers in distributed tests. Also applies to: 26-26, 28-28
131-136: LGTM: AWQ-Lite mocking preserves debug attributes for validation. The mock wrapper forces debug=True to retain awq_lite attributes on quantized layers, enabling the test assertions to verify cross-rank synchronization. Also applies to: 139-139
140-140: LGTM: Added AWQ-Lite act_scale validation for tensor parallelism. The signature change accommodates the mock decorator, and the new act_scale validation on lines 167-180 directly addresses the PR objective to synchronize AWQ-Lite scales in TP. Also applies to: 154-180
Signed-off-by: Jennifer Chen <[email protected]>
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
    # Check activation scale for AWQ lite
    _reduce_quantizer_attr(
@realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the c_in
dimension in activation ... is this right?
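A standalone toy illustration of that intuition (not the repo's code; sizes and names are made up): a ColumnParallelLinear sees the full, replicated activation on every TP rank, so a per-channel act_scale is comparable across ranks, while a RowParallelLinear only sees its own c_in shard, so each rank's act_scale covers a different slice of channels and cannot be compared or averaged naively.

# Toy sketch: why row-parallel act_scale differs per TP rank (assumed shapes).
import torch

hidden, ffn, tp = 8, 16, 2
x_col_in = torch.randn(4, hidden)            # column-parallel input: replicated on every TP rank
x_row_in = torch.randn(4, ffn)               # row-parallel input: split along c_in across TP ranks
x_row_shard = x_row_in.chunk(tp, dim=-1)[0]  # rank 0 only sees its own channel shard

col_scale = x_col_in.abs().mean(dim=0)       # identical channels on all TP ranks -> safe to compare
row_scale = x_row_shard.abs().mean(dim=0)    # per-rank shard covers different channels
print(col_scale.shape, row_scale.shape)      # torch.Size([8]) torch.Size([8]), but over different channels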
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/_test_utils/torch_quantization/quantize_common.py (1)
168-176: Don't tear down the global process group in a helper.
tensor_parallel_test_helper never creates the default process group, yet it calls dist.destroy_process_group() on exit. This blows away the global NCCL world for the remaining logic in the spawned worker, so any follow-up cleanup (e.g., destroy_model_parallel() or other helpers) will trip over an uninitialized backend. Please drop this call and leave lifecycle management to the test harness.
-    dist.destroy_process_group()
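A minimal sketch of the pattern being requested (hypothetical harness code, not taken from this repository): the spawn harness owns init and teardown of the default process group, and test helpers only use it.

# Hypothetical harness-owned process-group lifecycle (illustrative only).
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    try:
        # Call test helpers here; they must not destroy the group themselves.
        pass
    finally:
        dist.destroy_process_group()  # teardown happens exactly once, in the harness

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)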
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
modelopt/torch/quantization/model_calib.py
(3 hunks)modelopt/torch/quantization/plugins/megatron.py
(3 hunks)modelopt/torch/utils/distributed.py
(1 hunks)tests/_test_utils/torch_dist/plugins/megatron_common.py
(3 hunks)tests/_test_utils/torch_quantization/quantize_common.py
(3 hunks)tests/gpu/torch/quantization/plugins/test_megatron.py
(5 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/utils/distributed.py
- modelopt/torch/quantization/model_calib.py
- tests/_test_utils/torch_dist/plugins/megatron_common.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
parallel_state
(876-878)parallel_state
(881-886)modelopt/torch/utils/distributed.py (1)
ParallelState
(232-253)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
data_tensor_context_parallel_test_helper
(225-273)dp_cp_parallel_test_helper
(180-221)tensor_parallel_test_helper
(140-176)tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
initialize_for_megatron
(390-405)MegatronModel
(85-135)tests/gpu/torch/conftest.py (2)
need_2_gpus
(32-34)need_8_gpus
(38-40)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
SequentialQuantizer
(1114-1222)modelopt/torch/quantization/model_calib.py (1)
awq_lite
(424-671)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_dummy_input
(130-135)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
    """Function to mock awq_lite function to always use debug=True for testing"""
    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
|
Forward the AWQ-Lite kwargs in the patch
The _debug_awq_lite
wrapper drops every extra keyword argument that callers pass to awq_lite
(e.g., tensor_parallel_group
, data_parallel_group
, max_calib_steps
). The upstream API explicitly accepts **kwargs
, so the first call that includes one of those options will now raise a TypeError
, breaking AWQ-Lite calibration in the very paths this PR is exercising. Please mirror the original signature and forward **kwargs
to original_awq_lite
while forcing debug=True
.
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
- """Function to mock awq_lite function to always use debug=True for testing"""
- return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+ """Force awq_lite debug mode during tests without dropping optional args."""
+ return original_awq_lite(
+ model,
+ forward_loop,
+ alpha_step=alpha_step,
+ debug=True,
+ **kwargs,
+ )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): | |
"""Function to mock awq_lite function to always use debug=True for testing""" | |
return original_awq_lite(model, forward_loop, alpha_step, debug=True) | |
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs): | |
"""Force awq_lite debug mode during tests without dropping optional args.""" | |
return original_awq_lite( | |
model, | |
forward_loop, | |
alpha_step=alpha_step, | |
debug=True, | |
**kwargs, | |
) |
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 134 to
137, the _debug_awq_lite wrapper drops any extra keyword arguments callers pass
to awq_lite which causes TypeError when upstream calls include options like
tensor_parallel_group or max_calib_steps; update the wrapper to mirror the
original awq_lite signature by accepting *args and **kwargs (or the same
explicit params plus **kwargs) and forward them to original_awq_lite while
forcing debug=True (i.e., call original_awq_lite(..., debug=True, **kwargs) so
all upstream options are preserved).
Actionable comments posted: 2
♻️ Duplicate comments (4)
tests/_test_utils/torch_quantization/quantize_common.py (4)
131-136
: Forward kwargs to avoid TypeError in awq_lite calls.The
_debug_awq_lite
wrapper drops any extra keyword arguments (e.g.,tensor_parallel_group
,data_parallel_group
), which will causeTypeError
when callers pass those options.Apply this fix:
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): +def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs): """Function to mock awq_lite function to always use debug=True for testing""" - return original_awq_lite(model, forward_loop, alpha_step, debug=True) + return original_awq_lite(model, forward_loop, alpha_step=alpha_step, debug=True, **kwargs)Based on past review comments.
224-274
: Add fc2.awq_lite.act_scale validation to match DP/CP helper.The combined DP×TP×CP helper only checks
fc1.awq_lite.act_scale
(lines 269-273), leavingfc2
unchecked. A regression in the row-parallel branch would go undetected.Apply this fix:
# Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, ) + _reduce_quantizer_attr( + model.fc2.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + )Based on past review comments.
122-128
: Remove noisy debug prints and fix type annotation.Issues:
- Line 122:
attr: str
is correct (notattr=str
)- Lines 124, 126-127: Unconditional prints spam output across all ranks
Apply this fix:
-def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): +def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): quantizer_attr = getattr(quantizer, attr).clone() - print("quantizer.attr before reduce", getattr(quantizer, attr)) dist.all_reduce(quantizer_attr, op=op, group=group) - print("quantizer.attr after reduce", getattr(quantizer, attr)) - print("quantizer_attr after reduce", quantizer_attr) assert torch.allclose(quantizer_attr, getattr(quantizer, attr))Based on past review comments.
139-176
: Remove process group destruction from helper.Line 176 calls
dist.destroy_process_group()
, which unconditionally tears down the global process group and will break subsequent tests. The test harness manages the lifecycle.Apply this fix:
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check activation scale for AWQ lite _reduce_quantizer_attr( model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) - - dist.destroy_process_group()Based on past review comments.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
modelopt/torch/quantization/model_calib.py
(3 hunks)modelopt/torch/quantization/plugins/megatron.py
(3 hunks)modelopt/torch/utils/distributed.py
(1 hunks)tests/_test_utils/torch_dist/plugins/megatron_common.py
(3 hunks)tests/_test_utils/torch_quantization/quantize_common.py
(3 hunks)tests/gpu/torch/quantization/plugins/test_megatron.py
(5 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/plugins/megatron.py (2)
modelopt/torch/opt/dynamic.py (2)
parallel_state
(876-878)parallel_state
(881-886)modelopt/torch/utils/distributed.py (1)
ParallelState
(232-253)
modelopt/torch/quantization/model_calib.py (2)
modelopt/torch/utils/distributed.py (2)
is_initialized
(49-51)is_initialized
(196-198)modelopt/torch/quantization/utils.py (1)
is_quantized_linear
(246-256)
tests/_test_utils/torch_quantization/quantize_common.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
SequentialQuantizer
(1114-1222)modelopt/torch/quantization/model_calib.py (1)
awq_lite
(424-671)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_dummy_input
(130-135)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_quantization/quantize_common.py (3)
data_tensor_context_parallel_test_helper
(225-273)dp_cp_parallel_test_helper
(180-221)tensor_parallel_test_helper
(140-176)tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
initialize_for_megatron
(390-405)MegatronModel
(85-135)tests/gpu/torch/conftest.py (2)
need_2_gpus
(32-34)need_8_gpus
(38-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/utils/distributed.py (1)
249-253
: LGTM: Improved readability for ParallelState representation.The multi-line format with trailing commas makes the output easier to read, especially when debugging multi-group parallel configurations.
modelopt/torch/quantization/plugins/megatron.py (2)
18-18
: LGTM: Proper logging setup for CP-aware DP group handling.The logger initialization and import of
get_data_parallel_group
are necessary for the new fallback logic in_setup
.Also applies to: 26-26, 43-44
224-233
: LGTM: Safe fallback for context-parallel group initialization.The try/except block correctly handles the case where CP is not initialized:
- Attempts to get the DP group with CP awareness first
- Catches the AssertionError when CP is disabled
- Logs a helpful warning
- Falls back to the standard DP group
This resolves the past review concern about CP-disabled setups crashing.
Based on past review comments.
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
86-96
: LGTM: MegatronModel now supports context parallelism.The addition of
cp_size
parameter and its propagation toTransformerConfig
ascontext_parallel_size
correctly enables context-parallel testing.
130-135
: LGTM: Seeded dummy input generation enables per-rank divergence.The optional
seed
parameter allows tests to generate different calibration data across ranks, which is essential for validating DP/CP synchronization logic.
390-405
: LGTM: Context parallelism properly wired into initialization.The
context_parallel_size
parameter is correctly:
- Added after
seed
to maintain backward compatibility- Forwarded to
initialize_model_parallel
- Documented in the function
Based on past review comments.
modelopt/torch/quantization/model_calib.py (3)
83-91
: LGTM: Data parallel amax synchronization correctly implemented.The function properly:
- Handles SequentialQuantizer recursively
- Syncs amax across the data parallel group when present
- Returns early for SequentialQuantizer to avoid double-processing children
119-119
: Minor: Code marker for TP sync is helpful but non-functional.The comment serves as documentation for the TP synchronization logic that follows.
602-608
: LGTM: Activation scale DP synchronization helper.Clean helper function that averages
act_scale
across the data parallel group when initialized.tests/gpu/torch/quantization/plugins/test_megatron.py (4)
34-36
: LGTM: Correct imports for new DP/CP test helpers.The new imports enable testing of data-parallel and context-parallel synchronization.
97-103
: LGTM: Tensor parallel test properly updated.The test correctly:
- Initializes with TP=2
- Creates model with
tp_size=size
- Passes only the TP group to the helper
123-145
: LGTM: Data parallel test validates DP synchronization.The test correctly:
- Uses
seed=SEED+rank
to ensure divergent calibration data across ranks- Creates a non-TP/CP model (DP-only)
- Validates synchronization using the DP group
147-173
: LGTM: Context parallel test validates CP synchronization.The test correctly:
- Initializes CP with
context_parallel_size=size
- Uses
seed=SEED+rank
for per-rank divergence- Creates model with
cp_size=size
- Validates using the CP-aware DP group
tests/_test_utils/torch_quantization/quantize_common.py (2)
16-16
: LGTM: Correct imports for patching and SequentialQuantizer support.The imports enable mocking
awq_lite
to force debug mode and handling SequentialQuantizer validation.Also applies to: 26-29
179-222
: LGTM: DP/CP test helper validates synchronization correctly.The helper properly:
- Uses per-rank dummy input
- Validates input/weight quantizer amax across the group
- Handles SequentialQuantizer
- Validates AWQ-Lite act_scale for both fc1 and fc2
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
):
    module.awq_lite.is_enabled = False
    # Hack: MoEs forward all tokens through all experts if _if_calib is True
    module._if_calib = True
else:
    sync_act_scale_across_dp(
        module,
        module.parallel_state.data_parallel_group,
    )
🧩 Analysis chain
Verify: NaN check placement may cause collective operation mismatch.
The current implementation places the sync call in the else
block (lines 624-627), which means:
- Rank A (has NaN): sets
is_enabled = False
, skips else block, does NOT callsync_act_scale_across_dp
- Rank B (no NaN): enters else block, CALLS
sync_act_scale_across_dp
Since dist.all_reduce
(inside sync_act_scale_across_dp
at line 606) is a collective operation requiring all ranks to participate, this creates a deadlock when ranks disagree on NaN presence.
Past review comments suggest this should be resolved by either:
- Performing a collective NaN check first, then all ranks sync (or skip) together
- Moving NaN detection after sync
Consider adding a collective NaN check before the conditional:
# Check if ANY rank has NaN
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
group=module.parallel_state.data_parallel_group.group)
if has_nan.item() > 0:
module.awq_lite.is_enabled = False
else:
sync_act_scale_across_dp(module, module.parallel_state.data_parallel_group)
Based on past review comments.
🏁 Script executed:
#!/bin/bash
rg -n "sync_act_scale_across_dp" -C3
Length of output: 1217
🏁 Script executed:
#!/bin/bash
sed -n '580,650p' modelopt/torch/quantization/model_calib.py
Length of output: 3165
Collectively sync NaN detection before calling sync_act_scale_across_dp
Perform a global NaN check across all ranks, then either disable calibration everywhere or call the collective sync together. For example, replace the per-rank branch (model_calib.py ~621–627) with:
has_nan_local = (
torch.any(torch.isnan(module.awq_lite.act_scale))
or torch.any(torch.isnan(module.awq_lite.weight_scale))
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
group=module.parallel_state.data_parallel_group.group)
if has_nan.item() > 0:
module.awq_lite.is_enabled = False
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
)
This ensures every rank participates in the collective operation and prevents deadlock.
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 615 to 627, replace
the current per-rank NaN check and conditional call to sync_act_scale_across_dp
with a collective NaN detection: compute a has_nan_local boolean from act_scale
or weight_scale NaNs, create a tensor on the module weight/device with that
value, perform an all_reduce (MAX) across
module.parallel_state.data_parallel_group if it's initialized, then if the
reduced has_nan is >0 set module.awq_lite.is_enabled = False on all ranks,
otherwise call sync_act_scale_across_dp; ensure the tensor is on the correct
device and the collective uses the data_parallel_group to avoid deadlocks.
Signed-off-by: Jennifer Chen <[email protected]>
What does this PR do?
Type of change: New Feature
Overview: Sync quantizer amax in Context Parallelism & AWQ-Lite act_scale in CP/DP
Usage
# Add a code snippet demonstrating how to use this
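A hedged usage sketch (not taken from this PR; the model, config choice, and calibration data are placeholders). It assumes the TP/CP/DP process groups have already been initialized by the training framework, so calibration statistics are synchronized across ranks during mtq.quantize:

# Illustrative only: quantize with calibration; each DP/CP rank feeds its own shard,
# and amax / AWQ-Lite act_scale are all-reduced across the parallel groups.
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
calib_data = [torch.randn(8, 64, device="cuda") for _ in range(16)]

def forward_loop(m):
    for batch in calib_data:
        m(batch)

model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop)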
Testing
act_scale
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Style
Tests