
Commit 15ffb87

Code cleanup and test update
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 23daf38 commit 15ffb87

5 files changed (+175, -151 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 0 deletions
@@ -176,6 +176,10 @@ def sync_quantizer_amax_across_tp(
                     parallel_state=module.parallel_state,
                 )
 
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
 
 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 82 additions & 143 deletions
@@ -24,8 +24,6 @@
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
 import torch
-import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
-from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
@@ -41,47 +39,23 @@
 from ..nn import QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
-from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear
+from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
-logger = logging.getLogger(__name__)
-
-__all__ = []
-
-
-def sync_amax_across_sequential_mlp(model: torch.nn.Module):
-    """Sync amax across experts in a SequentialMLP."""
-    amax_dict = {}
-
-    def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
-        if (
-            isinstance(module, TensorQuantizer)
-            and hasattr(module, "_amax")
-            and ".local_experts." in name
-        ):
-            expert_name, local_expert_name = name.split(".local_experts.")
-            # extract quantizer name by removing local_expert number from the name
-            local_expert_name = ".".join(local_expert_name.split(".")[1:])
-            return f"{expert_name}.{local_expert_name}"
-        return None
+try:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelGroupedLinear,
+        TERowParallelGroupedLinear,
+    )
 
-    # gather amax values from SequentialMLP experts
-    for name, module in model.named_modules():
-        expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and module.amax is not None:
-            stored_amax = amax_dict.get(expert_name)
-            amax_tensor = module.amax.detach().clone()
-            amax_dict[expert_name] = (
-                amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
-            )
+    from .transformer_engine import _QuantTEGroupedLinear
 
-    # sync amax values across experts in SequentialMLP
-    for name, module in model.named_modules():
-        expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and module.amax is not None:
-            module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)
+    HAS_TE = True
+except ImportError:
+    HAS_TE = False
 
+logger = logging.getLogger(__name__)
 
-CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
+__all__ = []
 
 
 def real_quant_module_get_extra_state(self) -> dict:
@@ -516,111 +490,6 @@ def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
 
 
-# Register the public te.pytorch.GroupedLinear class
-@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
-class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
-    _functionals_to_replace = [
-        (te_grouped_linear._GroupedLinear, "forward"),
-        (te_grouped_linear._GroupedLinear, "apply"),
-    ]
-
-    def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        self.weight = self.weight0
-        # Memorize the original weight.dtype for modelopt_post_restore given that
-        # the dtype can change later.
-        super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
-
-    def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        self.weight = self.weight0
-        super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
-        # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
-        # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
-        # for modelopt checkpoint restore
-        filtered_state_dict = {
-            k: v
-            for k, v in state_dict.items()
-            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
-        }
-        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
-
-    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
-        assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
-        quantizer_state_dict[k] = v.view(-1)
-
-    @staticmethod
-    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
-        idx = 1 if func_name == "_forward" else 0
-        inp = args[idx]
-        num_gemms = len(args[idx + 1])
-        weights_and_biases = args[-2 * num_gemms :]
-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-        quantized_inputs = self.input_quantizer(inp)
-        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-        output = getattr(package, func_name)(
-            *(
-                args[0],
-                quantized_inputs,
-            )
-            if func_name == "_forward"
-            else (quantized_inputs,),
-            *args[idx + 1 : -2 * num_gemms],
-            *quantized_weights,
-            *biases,
-        )
-        return self.output_quantizer(output)
-
-    # Override the quantized linear function
-    _quantized_linear_fn = te_grouped_quantized_linear_fn
-
-
-@QuantModuleRegistry.register(
-    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
-)
-class _MegatronTEGroupedColumnParallelLinear(
-    _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
-):
-    _is_column_parallel = True
-
-
-@QuantModuleRegistry.register(
-    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
-)
-class _MegatronTEGroupedRowParallelLinear(
-    _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
-):
-    _is_row_parallel = True
-
-
-# Register the public megatron_moe.TEGroupedMLP class
-@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
-class _MegatronTEGroupedMLP(_MegatronMLP):
-    def _setup(self):
-        if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
-            )
-        # initialize parallel state for submodules linear_fc1 and linear_fc2
-        self.linear_fc1.parallel_state = self.parallel_state
-        self.linear_fc2.parallel_state = self.parallel_state
-
-
-# Register the public megatron_moe.SequentialMLP class
 @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
 class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
@@ -635,3 +504,73 @@ def _setup(self):
         for expert in self.local_experts:
             expert.linear_fc1.parallel_state = self.parallel_state
             expert.linear_fc2.parallel_state = self.parallel_state
+
+    def sync_moe_local_experts_amax(self):
+        """Sync amax across experts in a SequentialMLP."""
+        amax_dict = {}
+        # gather amax values from SequentialMLP experts
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    stored_amax = amax_dict.get(name)
+                    amax_tensor = module.amax.detach().clone()
+                    amax_dict[name] = (
+                        amax_tensor
+                        if stored_amax is None
+                        else torch.maximum(stored_amax, amax_tensor)
+                    )
+
+        # sync amax values across experts in SequentialMLP
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+
+
+if HAS_TE:
+    # Quantized subclasses to support TEGroupedMLP quantization
+    class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
+        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+            # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
+            # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
+            # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
+            # for modelopt checkpoint restore
+            filtered_state_dict = {
+                k: v
+                for k, v in state_dict.items()
+                if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+            }
+            return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+        def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+            assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+            quantizer_state_dict[k] = v.view(-1)
+
+    @QuantModuleRegistry.register(
+        {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedColumnParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register(
+        {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedRowParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
+    class _MegatronTEGroupedMLP(_MegatronMLP):
+        def _setup(self):
+            if not hasattr(self, "parallel_state") or self.parallel_state is None:
+                self.parallel_state = ParallelState(
+                    mcore_parallel.get_expert_data_parallel_group(),
+                    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                )
+            # initialize parallel state for submodules linear_fc1 and linear_fc2
+            self.linear_fc1.parallel_state = self.parallel_state
+            self.linear_fc2.parallel_state = self.parallel_state
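The new sync_moe_local_experts_amax method follows a gather-then-broadcast pattern: take the elementwise maximum of each quantizer's amax over all local experts, then write the merged value back so every expert quantizes with the same scale. A rough, self-contained sketch of just that pattern, with plain dicts standing in for the per-expert quantizer modules and made-up amax values:

import torch

# Two "experts", each mapping quantizer name -> amax (illustrative values only).
experts = [
    {"linear_fc1.input_quantizer": torch.tensor([0.5]),
     "linear_fc1.weight_quantizer": torch.tensor([2.0])},
    {"linear_fc1.input_quantizer": torch.tensor([0.4]),
     "linear_fc1.weight_quantizer": torch.tensor([3.0])},
]

# Gather pass: elementwise maximum of each quantizer's amax across experts.
amax_dict = {}
for expert in experts:
    for name, amax in expert.items():
        prev = amax_dict.get(name)
        amax_dict[name] = amax.clone() if prev is None else torch.maximum(prev, amax)

# Sync pass: write the merged amax back so all experts share one scale.
for expert in experts:
    for name in expert:
        expert[name] = amax_dict[name].clone()

assert torch.equal(experts[0]["linear_fc1.weight_quantizer"], torch.tensor([3.0]))

Keeping the method on _MegatronSequentialMLP (and dispatching to it from model_calib.py via hasattr) replaces the earlier name-parsing plugin that walked the whole model looking for ".local_experts." substrings.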

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 56 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import torch
 import transformer_engine as te
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
 import transformer_engine.pytorch.module.linear as te_linear
 
 from ..nn import QuantModuleRegistry
@@ -58,3 +59,58 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
 
     # Override the quantized linear function
     _quantized_linear_fn = te_quantized_linear_fn
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
+class _QuantTEGroupedLinear(_ParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
+    def _setup(self):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        self.weight = self.weight0
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        super()._setup()
+        # Remove self.weight after setup.
+        delattr(self, "weight")
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        self.weight = self.weight0
+        super().modelopt_post_restore(prefix=prefix)
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
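The subtle part of te_grouped_quantized_linear_fn is the argument slicing: the grouped-linear functional receives the per-group weights and biases flattened at the end of its positional argument list, so the last 2 * num_gemms entries are split into weights then biases, the input is quantized once, and each group's weight is quantized separately before the original functional is called. A rough stand-alone sketch of that slicing under assumed, simplified arguments (the names and argument layout here are illustrative, not TE's exact signature):

import torch


def split_grouped_args(args, num_gemms):
    # The trailing 2 * num_gemms positional args are the per-group weights
    # followed by the per-group biases; everything before them is untouched.
    weights_and_biases = args[-2 * num_gemms:]
    return weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]


def fake_quantize(t):
    # Hypothetical stand-in for a TensorQuantizer call (identity here).
    return t


num_gemms = 2
inp = torch.randn(8, 4)
m_splits = [4, 4]  # rows of `inp` routed to each group
args = (inp, m_splits, True, torch.randn(4, 4), torch.randn(4, 4), None, None)

weights, biases = split_grouped_args(args, num_gemms)
quantized_inputs = fake_quantize(inp)                    # input quantized once
quantized_weights = [fake_quantize(w) for w in weights]  # one pass per group's weight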

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 24 additions & 7 deletions
@@ -190,7 +190,7 @@ def squared_relu(x):
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
-        sequence_parallel=False,
+        sequence_parallel=expert_model_parallel_size > 1,
         moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
@@ -215,7 +215,8 @@ def squared_relu(x):
             num_experts=num_moe_experts,
             normalization=normalization,
             moe_grouped_gemm=moe_grouped_gemm,
-            use_te=use_te,
+            # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM
+            # use_te=use_te,
         )
     else:
         assert HAS_TE, "Transformer Engine not installed"
@@ -563,7 +564,8 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
         if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"):
             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-                amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
+                # Convert to scalar only if tensor has a single element
+                amax_val = module.amax.detach().clone().cpu()
                 expert_amax_values[name] = amax_val
 
     # Early return if no expert quantizers found
@@ -594,7 +596,13 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
             ):
                 if compare_across_experts:
                     # compare expert value across expert for sequential MoE
-                    assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
+                    prev_val = expert_quantizers[quantizer_type][rank_idx]
+                    # Handle both scalar and tensor comparisons
+                    if isinstance(amax_val, torch.Tensor) and isinstance(prev_val, torch.Tensor):
+                        are_equal = torch.allclose(prev_val, amax_val, rtol=1e-6, atol=1e-6)
+                    else:
+                        are_equal = prev_val == amax_val
+                    assert are_equal, (
                         f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
                         f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
                     )
@@ -604,8 +612,17 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
     for quantizer_type, rank_values in expert_quantizers.items():
         if len(rank_values) > 1:  # Only check if we have multiple ranks
             values = list(rank_values.values())
-            max_diff = max(values) - min(values)
-            if max_diff > 1e-6:  # Allow for small floating point differences
-                return False, quantizer_type, rank_values
+            # Handle both scalar and tensor comparisons
+            first_val = values[0]
+            if isinstance(first_val, torch.Tensor):
+                # For tensors, check if all values are close to the first one
+                for val in values[1:]:
+                    if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
+                        return False, quantizer_type, rank_values
+            else:
+                # For scalars, use numeric comparison
+                max_diff = max(values) - min(values)
+                if max_diff > 1e-6:  # Allow for small floating point differences
+                    return False, quantizer_type, rank_values
 
     return True, None, None
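Because amax is now carried around as a tensor (it can have more than one element for per-channel quantizers), equality checks in the test helper use torch.allclose with a small tolerance instead of exact scalar comparison. A tiny sketch of that comparison logic, assuming the same 1e-6 tolerances used above:

import torch


def amax_equal(a, b, rtol=1e-6, atol=1e-6):
    # Tensors compare with a tolerance; plain scalars fall back to exact equality,
    # mirroring the branching added to the test helper above.
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.allclose(a, b, rtol=rtol, atol=atol)
    return a == b


assert amax_equal(torch.tensor([1.0]), torch.tensor([1.0 + 1e-8]))
assert not amax_equal(torch.tensor([1.0]), torch.tensor([2.0]))
assert amax_equal(0.25, 0.25)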
