Commit 23daf38

Code cleanup
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 5bc99e0

3 files changed (+29, -46 lines)

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 6 additions & 14 deletions
@@ -611,13 +611,9 @@ class _MegatronTEGroupedMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )
         # initialize parallel state for submodules linear_fc1 and linear_fc2
         self.linear_fc1.parallel_state = self.parallel_state
@@ -630,13 +626,9 @@ class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )
 
         # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
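After this change, both _setup overrides build the expert ParallelState in exactly the same way and assume the Megatron-Core expert parallel groups are already initialized. A minimal sketch of how the now-identical construction could be factored out; the helper name and the import alias are illustrative assumptions, not part of this commit:

# A minimal sketch, not part of this commit: the ParallelState construction that
# _MegatronTEGroupedMLP._setup and _MegatronSequentialMLP._setup now share.
# The helper name and the import alias are assumptions for illustration only.
from megatron.core import parallel_state as mcore_parallel


def build_expert_parallel_state(parallel_state_cls):
    """Build the expert ParallelState from already-initialized Megatron-Core groups."""
    return parallel_state_cls(
        mcore_parallel.get_expert_data_parallel_group(),
        tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
        expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
    )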

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 15 additions & 13 deletions
@@ -515,20 +515,21 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo
 
     # Map grouped weights to sequential weights
     weight_mapping = {}
-    sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight"
+    sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}"
     for key, value in te_grouped_state.items():
-        if "experts.linear_fc" in key and "weight" in key:
+        if "experts.linear_fc" in key and any(param in key for param in ("weight", "bias")):
             # Extract expert index from grouped weight name
             # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ
             parts = key.split(".")
             layer_idx = parts[2]  # X
             fc_idx = parts[5]  # Y (linear_fc1 or linear_fc2)
-            weight_idx = parts[6]  # Z (weight0, weight1, etc.)
-
-            # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight
-            expert_idx = weight_idx.replace("weight", "")
+            param_idx = parts[6]  # weight0 / bias0 / etc.
+            match = re.search(r"\d+", param_idx)
+            expert_idx = match.group(0) if match else "0"  # Z for expert index
+            # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ
             sequential_key = sequential_key_template.format(layer_idx, expert_idx, fc_idx[-1])
-            weight_mapping[sequential_key] = value
+            param_name = "weight" if "weight" in param_idx else "bias"
+            weight_mapping[f"{sequential_key}.{param_name}"] = value
         elif isinstance(value, torch.Tensor):
             weight_mapping[key] = value
 
@@ -540,7 +541,7 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo
     sequential_moe_model.load_state_dict(sequential_state)
 
 
-def compare_amax_sync_across_expert_parallel(model):
+def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True):
     """
     Test if amax values are synchronized across expert parallel groups.
 
@@ -591,11 +592,12 @@ def compare_amax_sync_across_expert_parallel(model):
                 quantizer_type in expert_quantizers
                 and rank_idx in expert_quantizers[quantizer_type]
             ):
-                # compare expert value across expert for sequential MoE
-                assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
-                    f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
-                    f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
-                )
+                if compare_across_experts:
+                    # compare expert value across expert for sequential MoE
+                    assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
+                        f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
+                        f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
+                    )
             expert_quantizers[quantizer_type][rank_idx] = amax_val
 
     # Check synchronization - fail fast on first inconsistency
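The grouped-to-sequential key translation introduced in the first hunk of this file is easy to exercise on its own. Below is a self-contained sketch of the same mapping logic; the function name and the sample key are illustrative, not taken from the repo:

import re

SEQUENTIAL_KEY_TEMPLATE = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}"


def remap_grouped_key(key: str) -> str:
    """Translate a TEGrouped key such as ...experts.linear_fc1.weight3 to the sequential layout."""
    parts = key.split(".")
    layer_idx = parts[2]   # layer index X
    fc_idx = parts[5]      # "linear_fc1" or "linear_fc2"
    param_idx = parts[6]   # "weight3", "bias0", ...
    match = re.search(r"\d+", param_idx)
    expert_idx = match.group(0) if match else "0"
    param_name = "weight" if "weight" in param_idx else "bias"
    base = SEQUENTIAL_KEY_TEMPLATE.format(layer_idx, expert_idx, fc_idx[-1])
    return f"{base}.{param_name}"


# weight3 of the grouped linear_fc1 belongs to local expert 3:
print(remap_grouped_key("decoder.layers.0.mlp.experts.linear_fc1.weight3"))
# decoder.layers.0.mlp.experts.local_experts.3.linear_fc1.weight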

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 8 additions & 19 deletions
@@ -677,31 +677,20 @@ def forward_fn(model):
     assert initial_sync, (
         f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
     )
-    # Create inconsistent amax values
-    cur_rank = torch.distributed.get_rank()
-    for name, module in model.named_modules():
-        if isinstance(module, mtq.nn.TensorQuantizer):
-            # Check if this is an expert quantizer
-            is_expert_quantizer = (
-                "local_experts" in name  # sequential MoE
-                or ("experts" in name and "linear_fc" in name)  # TEGrouped MoE
-            )
 
-            if is_expert_quantizer and hasattr(module, "_amax"):
-                # Create rank-specific amax values to simulate missing sync
-                rank_offset = cur_rank * 0.1
-                module.amax = module.amax + rank_offset
+    # Test if the amax values are inconsistent when distributed sync is disabled
+    mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=False)
+    inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel(
+        model, compare_across_experts=False
+    )
 
-    # Test if the amax values are inconsistent
-    inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel(model)
     assert not inconsistent_amax, (
         "Consistent amax across expert parallel ranks, "
         "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled"
     )
-    # Re-calibrate the model and test synchronization
-    mtq.mode.wrapped_calib_func(
-        model, mtq.config.MaxCalibConfig(), forward_fn, mtq.model_calib.max_calibrate
-    )
+    # calibrate the model with distributed sync and test synchronization
+    mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True)
+    mtq.plugins.megatron.sync_amax_across_sequential_mlp(model)
 
     final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
     assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
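For reference, a rank-consistency check in the spirit of compare_amax_sync_across_expert_parallel can be sketched with torch.distributed collectives. This is a simplified illustration rather than the repo's helper; it assumes torch.distributed is already initialized and that every rank enumerates the same expert quantizer names, as the replicated test models here do:

import torch
import torch.distributed as dist


def expert_amax_in_sync(model) -> bool:
    """Return True if all ranks hold identical amax values for their expert quantizers."""
    # Collect this rank's expert-quantizer amax tensors, keyed by module name
    # (same name patterns as the test helper above).
    local = {
        name: module.amax.detach().float().cpu()
        for name, module in model.named_modules()
        if ("local_experts" in name or ("experts" in name and "linear_fc" in name))
        and hasattr(module, "_amax")
    }
    # Gather one dict per rank and compare every rank against rank 0.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local)
    for name, reference in gathered[0].items():
        for rank_dict in gathered[1:]:
            if name in rank_dict and not torch.allclose(rank_dict[name], reference):
                return False
    return True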
