
Commit d19f4f5

Updated amax_sync test to set const weights based on rank (#451)
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 4476f21 commit d19f4f5

File tree: 2 files changed (+17, -1 lines)

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ Model Optimizer Changelog (Linux)
 - Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` (gated dataset accessed using ``HF_TOKEN`` environment variable) if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
+- Add support for MCore MoE PTQ/QAT/QAD.

 **Documentation**

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 16 additions & 1 deletion
@@ -673,7 +673,22 @@ def _test_expert_model_parallel_amax_sync(
         num_moe_experts=8,
         transformer_impl="modelopt",
     )
-    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    # Initialize ALL weights based on rank to produce different amax values
+    # across ranks that need synchronization
+    weight_idx = 0
+    for name, param in model.named_parameters():
+        # Skip embeddings and any parameters without 'weight' in the name
+        if "embedding" in name.lower() or "weight" not in name.lower():
+            continue
+
+        if param.requires_grad and param.dim() >= 2:  # Only weight matrices, not biases
+            # Different constant value based on rank and parameter index
+            const_val = 0.1 + (rank * 0.5) + (weight_idx * 0.05)
+            param.data.fill_(const_val)
+            weight_idx += 1
+
+    prompt_tokens = (torch.ones((2, model.max_sequence_length)) * 0.05 + rank * 0.5).cuda().long()

     # force all expert routing
     for module in model.modules():
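
The change above makes each rank's weights a different constant, so each rank calibrates a different per-tensor amax; a correct amax sync must then bring all ranks back to the same value. The snippet below is a minimal standalone sketch of that idea, not part of the commit and not the modelopt API: local_amax and sync_amax are hypothetical helpers, and the MAX all-reduce over the default process group stands in for the library's expert-parallel amax synchronization.

    # Run with: torchrun --nproc_per_node=2 amax_sync_sketch.py
    import torch
    import torch.distributed as dist


    def local_amax(weight: torch.Tensor) -> torch.Tensor:
        # Per-tensor absolute maximum, the statistic a quantizer calibrates from.
        return weight.abs().amax()


    def sync_amax(amax: torch.Tensor, group=None) -> torch.Tensor:
        # Hypothetical stand-in for amax synchronization: a MAX all-reduce so
        # every rank ends up with the same calibration scale.
        synced = amax.clone()
        dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
        return synced


    if __name__ == "__main__":
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        # Mirror the test: a constant weight whose value depends on the rank,
        # so the local amax differs across ranks before synchronization.
        weight = torch.full((16, 16), 0.1 + rank * 0.5, device="cuda")
        amax = local_amax(weight)
        print(f"rank {rank}: local amax = {amax.item():.3f}")

        amax = sync_amax(amax)
        print(f"rank {rank}: synced amax = {amax.item():.3f}")  # identical on all ranks

        dist.destroy_process_group()

After the sync, every rank reports the largest of the per-rank amax values, which is what the updated test asserts indirectly by checking that quantizer amax buffers agree across expert-parallel ranks.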
