
Commit 8bff6b0

fixed tests for per-tensor support
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 28c8bbf commit 8bff6b0

File tree

3 files changed: +87 −41 lines changed


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 9 additions & 3 deletions

@@ -506,9 +506,15 @@ def _setup(self):
             expert.linear_fc2.parallel_state = self.parallel_state

     def sync_moe_local_experts_amax(self):
-        """Sync amax across experts in a SequentialMLP."""
+        """Sync amax across local experts in a SequentialMLP.
+
+        Amax across EP and ETP (for RowParallel) is synchronized as part of
+        model_calib.max_calibrate(). This function synchronizes the amax values
+        across local experts so that all local experts share the same amax.
+        """
+        torch.distributed.barrier()
+        # Collect amax from all local experts
         amax_dict = {}
-        # gather amax values from SequentialMLP experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
@@ -520,7 +526,7 @@ def sync_moe_local_experts_amax(self):
                         else torch.maximum(stored_amax, amax_tensor)
                     )

-        # sync amax values across experts in SequentialMLP
+        # Apply synchronized amax values back to all local experts
        for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
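
For context, the pattern introduced above is a two-pass max-reduction: collect the elementwise maximum amax per quantizer name across all local experts, then write that maximum back so every local expert ends up with the same calibration range. Below is a minimal, self-contained sketch of the same idea in plain PyTorch; sync_amax_across_experts and the duck-typed amax attribute are illustrative stand-ins, not the ModelOpt API.

import torch
import torch.nn as nn


def sync_amax_across_experts(local_experts: list[nn.Module]) -> None:
    """Give every expert the per-quantizer elementwise-max amax (illustrative sketch)."""
    amax_dict: dict[str, torch.Tensor] = {}
    # Pass 1: collect the running maximum per quantizer name across all experts.
    for expert in local_experts:
        for name, module in expert.named_modules():
            amax = getattr(module, "amax", None)
            if amax is not None:
                if name not in amax_dict:
                    amax_dict[name] = amax.detach().clone()
                else:
                    amax_dict[name] = torch.maximum(amax_dict[name], amax)
    # Pass 2: write the shared maximum back into every expert.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if getattr(module, "amax", None) is not None and name in amax_dict:
                module.amax.data.copy_(amax_dict[name])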

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 57 additions & 28 deletions

@@ -14,6 +14,7 @@
 # limitations under the License.
 import copy
 import re
+from collections import defaultdict
 from warnings import warn

 import torch
@@ -41,6 +42,7 @@
 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
     get_expert_tensor_parallel_group,
+    get_expert_tensor_parallel_rank,
     initialize_model_parallel,
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -190,7 +192,7 @@ def squared_relu(x):
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
-        sequence_parallel=expert_model_parallel_size > 1,
+        sequence_parallel=False,
         moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
@@ -565,8 +567,7 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
                 # Convert to scalar only if tensor has a single element
-                amax_val = module.amax.detach().clone().cpu()
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()

     # Early return if no expert quantizers found
     assert expert_amax_values, "No expert quantizers found"
@@ -577,19 +578,16 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
     torch.distributed.all_gather_object(all_amax_values, expert_amax_values)

     # Group quantizers by type (ignoring specific expert indices) and check sync
-    expert_quantizers = {}
+    expert_quantizers = defaultdict(dict)
     for rank_idx, rank_amax in enumerate(all_amax_values):
         for name, amax_val in rank_amax.items():
             # Create quantizer type key by normalizing the name
-            if "local_experts" in name:
-                # sequential MoE: replace expert index with wildcard
-                quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
-            else:
-                # TEGrouped MoE: use the name as-is since experts are grouped
-                quantizer_type = name
-
-            if quantizer_type not in expert_quantizers:
-                expert_quantizers[quantizer_type] = {}
+            quantizer_type = (
+                re.sub(r"local_experts\.\d+", "local_experts.*", name)
+                if "local_experts" in name
+                else name
+            )
+
             if (
                 quantizer_type in expert_quantizers
                 and rank_idx in expert_quantizers[quantizer_type]
@@ -608,21 +606,52 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
                 )
             expert_quantizers[quantizer_type][rank_idx] = amax_val

-    # Check synchronization - fail fast on first inconsistency
+    rank_info = {
+        "global_rank": torch.distributed.get_rank(),
+        "etp_rank": get_expert_tensor_parallel_rank(),
+    }
+
+    all_rank_info = [None] * world_size
+    torch.distributed.all_gather_object(all_rank_info, rank_info)
+
+    # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match)
+    etp_groups = defaultdict(list)
+    for info in all_rank_info:
+        etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"])
+
     for quantizer_type, rank_values in expert_quantizers.items():
-        if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            # Handle both scalar and tensor comparisons
-            first_val = values[0]
-            if isinstance(first_val, torch.Tensor):
-                # For tensors, check if all values are close to the first one
-                for val in values[1:]:
-                    if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
-                        return False, quantizer_type, rank_values
-            else:
-                # For scalars, use numeric comparison
-                max_diff = max(values) - min(values)
-                if max_diff > 1e-6:  # Allow for small floating point differences
-                    return False, quantizer_type, rank_values
+        # Determine which ranks should share the same amax:
+        #
+        # fc1 (ColumnParallel): X @ [A_1, A_2] (weights split along Cout),
+        #     so amax should match across ranks with the same ETP rank.
+        #     With EP=2 and ETP=2 there are 4 ranks (EP1/ETP1: 0, EP1/ETP2: 1, EP2/ETP1: 2, EP2/ETP2: 3),
+        #     so amax is compared within each ETP rank: [0, 2] and [1, 3].
+        #
+        # fc2 (RowParallel): [X_1, X_2] @ [A_1
+        #                                  A_2] (weights split along Cin),
+        #     so amax should match across all ranks.
+
+        rank_groups = (
+            list(etp_groups.values())
+            if "linear_fc1" in quantizer_type
+            else [list(range(world_size))]
+        )
+        # Check each group independently
+        for group in rank_groups:
+            group_values = [rank_values[r] for r in group if r in rank_values]
+            if len(group_values) > 1:
+                # All values in this group should be identical
+                first_val = group_values[0]
+                for val in group_values[1:]:
+                    if isinstance(first_val, torch.Tensor):
+                        if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
+                            group_rank_values = {
+                                r: rank_values[r] for r in group if r in rank_values
+                            }
+                            return False, f"{quantizer_type} (group {group})", group_rank_values
+                    elif abs(first_val - val) > 1e-6:
+                        group_rank_values = {r: rank_values[r] for r in group if r in rank_values}
+                        return False, f"{quantizer_type} (group {group})", group_rank_values

     return True, None, None
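
To make the grouping above concrete, here is a small standalone example (pure Python, no Megatron) for the EP=2, ETP=2 layout assumed in the comment, where global ranks 0–3 map to ETP ranks 0, 1, 0, 1. fc1 (ColumnParallel) amax is compared within each ETP rank, while fc2 (RowParallel) amax is compared across all ranks; the actual rank-to-group mapping depends on how parallel state is initialized.

from collections import defaultdict

world_size = 4
# Assumed layout for EP=2, ETP=2: global rank -> ETP rank (illustrative only).
etp_rank_of = {0: 0, 1: 1, 2: 0, 3: 1}

etp_groups = defaultdict(list)
for global_rank, etp_rank in etp_rank_of.items():
    etp_groups[etp_rank].append(global_rank)

fc1_groups = list(etp_groups.values())   # ColumnParallel: compare within the same ETP rank
fc2_groups = [list(range(world_size))]   # RowParallel: compare across all ranks

print(fc1_groups)  # [[0, 2], [1, 3]]
print(fc2_groups)  # [[0, 1, 2, 3]]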

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 21 additions & 10 deletions

@@ -45,6 +45,7 @@
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP
+from megatron.core.transformer.moe.router import TopKRouter

 import modelopt
 import modelopt.torch.opt as mto
@@ -240,6 +241,7 @@ def _gpt_model_provider(
     ep_size=1,
     etp_size=None,
     use_te=False,
+    transformer_impl="local",
 ):
     """Build the model."""

@@ -253,7 +255,7 @@ def _gpt_model_provider(
         ffn_hidden_size=None,
         num_attention_heads=8,
         activation_func="squared_relu",
-        transformer_impl="local",
+        transformer_impl=transformer_impl,
         hidden_size=hidden_size,
         vocab_size=vocab_size,
         use_cpu_initialization=meta_device,
@@ -270,7 +272,7 @@ def _gpt_model_provider(
         ffn_hidden_size=None,
         num_attention_heads=8,
         activation_func="squared_relu",
-        transformer_impl="local",
+        transformer_impl=transformer_impl,
         hidden_size=hidden_size,
         vocab_size=vocab_size,
         num_moe_experts=num_moe_experts,
@@ -297,6 +299,7 @@ def _test_sharded_state_dict(
     num_moe_experts = moe_config.get("num_moe_experts", None)
     moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False)
     use_te = moe_config.get("use_te", False)
+    transformer_impl = moe_config.get("transformer_impl", "local")

     initialize_for_megatron(
         tensor_model_parallel_size=tp_size,
@@ -314,6 +317,7 @@ def _test_sharded_state_dict(
         use_te=use_te,
         ep_size=ep_size,
         etp_size=etp_size,
+        transformer_impl=transformer_impl,
     )
     model_test = _gpt_model_provider(
         tp_size,
@@ -325,6 +329,7 @@ def _test_sharded_state_dict(
         meta_device=meta_device,
         ep_size=ep_size,
         etp_size=etp_size,
+        transformer_impl=transformer_impl,
     )

     prompt_tokens = torch.randint(
@@ -531,10 +536,7 @@ def test_fp8_real_quantize():

 @pytest.mark.parametrize(
     "config",
-    [
-        mtq.FP8_DEFAULT_CFG,
-        mtq.NVFP4_DEFAULT_CFG,
-    ],
+    [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG],
 )
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
 def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
@@ -549,6 +551,7 @@ def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm)
         "num_moe_experts": 4,
         "moe_grouped_gemm": moe_grouped_gemm,
         "use_te": moe_grouped_gemm,
+        "transformer_impl": "modelopt",
     }
     spawn_multiprocess_job(
         size=size,
@@ -606,6 +609,7 @@ def forward_fn(model):
         hidden_size=32,
         moe_grouped_gemm=False,
         num_moe_experts=4,
+        transformer_impl="modelopt",
     )
     num_sequential_mlp = sum(
         isinstance(module, SequentialMLP) for module in sequential_moe_model.modules()
@@ -666,10 +670,16 @@ def _test_expert_model_parallel_amax_sync(
         hidden_size=256,
         moe_grouped_gemm=moe_grouped_gemm,
         use_te=moe_grouped_gemm,
-        num_moe_experts=4,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
     )
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

+    # force all expert routing
+    for module in model.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)

@@ -701,9 +711,10 @@ def forward_fn(model):
     assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"


+@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG])
 @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
+def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size * etp_size:
@@ -716,11 +727,11 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
         size=size,
         job=partial(
             _test_expert_model_parallel_amax_sync,
-            2,
+            etp_size,  # tp_size
             ep_size,
             etp_size,
             moe_grouped_gemm,
-            mtq.FP8_DEFAULT_CFG,
+            config,
         ),
         backend="nccl",
     )
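
The "force all expert routing" patch above makes every expert receive tokens during calibration; with a small top-k, some experts may see no tokens at all, leaving their quantizers without a meaningful amax to compare across ranks. A toy illustration of that effect with plain top-k selection (not the Megatron TopKRouter):

import torch

torch.manual_seed(0)
num_experts, num_tokens = 8, 4
router_logits = torch.randn(num_tokens, num_experts)


def routed_experts(topk: int) -> set[int]:
    # Experts that receive at least one token under top-k routing.
    return set(router_logits.topk(topk, dim=-1).indices.flatten().tolist())


print(len(routed_experts(2)))            # usually < 8: some experts get no tokens
print(len(routed_experts(num_experts)))  # 8: every expert is exercised during calibration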
