17 changes: 16 additions & 1 deletion tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -673,7 +673,22 @@ def _test_expert_model_parallel_amax_sync(
num_moe_experts=8,
transformer_impl="modelopt",
)
prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

# Initialize all weight matrices with rank- and index-dependent constants
# to produce different amax values across ranks that then need synchronization
weight_idx = 0
for name, param in model.named_parameters():
# Skip embeddings and any parameters without 'weight' in the name
if "embedding" in name.lower() or "weight" not in name.lower():
continue

if param.requires_grad and param.dim() >= 2: # Only weight matrices, not biases
# Different constant value based on rank and parameter index
const_val = 0.1 + (rank * 0.5) + (weight_idx * 0.05)
param.data.fill_(const_val)
weight_idx += 1

# Deterministic prompt tokens whose constant id depends on the rank,
# replacing the random prompt so the test is reproducible across runs
prompt_tokens = (torch.ones((2, model.max_sequence_length)) * 0.05 + rank * 0.5).cuda().long()

# force all expert routing
for module in model.modules():
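For context on what this setup is exercising: once each rank holds different amax values, the test presumably verifies that they become identical after synchronization. Below is a minimal sketch of such a check, not the PR's actual assertion; it assumes an already-initialized default process group and that modelopt quantizer modules expose an `amax` tensor attribute, and the helper name `assert_amax_synced` is hypothetical.

```python
import torch
import torch.distributed as dist


def assert_amax_synced(model: torch.nn.Module) -> None:
    """Sketch: gather every quantizer amax from all ranks and assert they match.

    Assumes quantizer modules carry an `amax` tensor attribute and that
    torch.distributed has been initialized by the test harness.
    """
    world_size = dist.get_world_size()
    for name, module in model.named_modules():
        amax = getattr(module, "amax", None)
        if amax is None:
            continue  # not a quantizer module, or amax not yet calibrated
        local = amax.detach().float().cuda()
        # Collect this quantizer's amax from every rank and compare against rank 0's copy
        gathered = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local)
        for other in gathered[1:]:
            assert torch.allclose(gathered[0], other), f"amax not synced for {name}"
```

In the actual test, a check like this would run on every rank after the forward passes and the amax synchronization step.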