
Commit 54a38c7

Added quantization support for TEGroupedMoE for megatron-lm
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 3f857a3 commit 54a38c7

File tree

6 files changed (+716, -17 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 9 additions & 4 deletions
@@ -80,21 +80,26 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel group."""
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
+            if parallel_state.expert_tensor_parallel_group is not None:
+                quantizer.sync_amax_across_distributed_group(
+                    parallel_state.expert_tensor_parallel_group
+                )
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
             # TP sync:
             # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
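For context, a minimal usage sketch (not part of this commit) of how the new sync path is reached: max calibration is run with distributed sync enabled, so each quantizer's amax is all-reduced across the data parallel and expert parallel groups. `quant_model` and `calib_loader` are hypothetical placeholders for an already-converted quantized Megatron MoE model and its calibration batches.

```python
from modelopt.torch.quantization.model_calib import max_calibrate

def forward_loop(model):
    # `calib_loader` is a hypothetical iterable of calibration batches.
    for batch in calib_loader:
        model(batch)

# With distributed_sync=True, the per-rank amax values collected above are synchronized via
# sync_quantizer_amax_across_dp_ep over the data parallel, expert model parallel, and
# (when initialized) expert tensor parallel groups.
max_calibrate(quant_model, forward_loop, distributed_sync=True)
```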

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 170 additions & 2 deletions
@@ -23,6 +23,8 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
+from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
@@ -35,8 +37,8 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
-from ..nn.modules.quant_linear import RealQuantLinear
+from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
+from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

@@ -227,9 +229,17 @@ def _setup(self):
         except AssertionError:
             logger.warning("Context parallel group is not initialized, using data parallel group")
             data_parallel_group = get_data_parallel_group()
+
+        try:
+            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
+        except AssertionError:
+            expert_tensor_parallel_group = None
+
         self.parallel_state = ParallelState(
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_expert_model_parallel_group(),
+            expert_tensor_parallel_group,
         )
         super()._setup()
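The expert groups queried above exist only when Megatron's model parallel state was initialized with expert parallelism. A hedged setup sketch follows, with keyword names assumed from recent megatron-core releases (exact names may vary by version).

```python
import torch
from megatron.core import parallel_state as mcore_parallel

# Assumes the usual torch.distributed environment variables are set by the launcher.
torch.distributed.init_process_group(backend="nccl")
mcore_parallel.initialize_model_parallel(
    tensor_model_parallel_size=2,
    expert_model_parallel_size=2,    # provides the expert model parallel group
    expert_tensor_parallel_size=1,   # provides the expert tensor parallel group (newer megatron-core)
)
# If expert tensor parallelism is not initialized, the plugin falls back gracefully:
# get_expert_tensor_parallel_group() raises AssertionError and the group is stored as None.
```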

@@ -472,3 +482,161 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"})
+class _QuantTEGroupedLinear(_MegatronParallelLinear):
+    def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+
+        try:
+            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
+        except AssertionError:
+            expert_tensor_parallel_group = None
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_expert_model_parallel_group(),
+            expert_tensor_parallel_group,
+        )
+        self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
+        self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
+        self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
+        self.output_quantizer.disable()
+
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype
+
+    @property
+    def functionals_to_replace(self):
+        original_forward = te_grouped_linear._GroupedLinear.forward
+
+        def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
+            num_gemms = len(m_splits)
+            weights_and_biases = args[-2 * num_gemms :]
+            weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+            quantized_inputs = self.input_quantizer(inp)
+            quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+            output = original_forward(
+                ctx,
+                quantized_inputs,
+                m_splits,
+                *args[: -2 * num_gemms],
+                *quantized_weights,
+                *biases,
+            )
+            return self.output_quantizer(output)
+
+        return [
+            (
+                te_grouped_linear._GroupedLinear,
+                "forward",
+                te_grouped_quantized_linear_fn,
+            ),
+        ]
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks.
+
+        ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their
+        shape before saving. However, this is not enough for MCore/distributed frameworks since the
+        tensor parallelism could change between saving and restoring. If the tensor parallelism changes,
+        the shape of the quantizer states also changes, so we need to re-calculate the quantizer states.
+        """
+        from modelopt.torch.quantization.model_calib import max_calibrate
+
+        def _check_unsupported_states(quantizer: TensorQuantizer):
+            for k in quantizer.state_dict():
+                if k not in ["_amax", "_pre_quant_scale"]:
+                    warnings.warn(
+                        f"Restore of {k} for {prefix} is not supported. The restore of this layer might be "
+                        f"incorrect. Please implement a custom restore for {k}."
+                    )
+
+        def _has_state(quantizer, name):
+            # Handling for SequentialQuantizer
+            quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer
+            return hasattr(quantizer, name)
+
+        # Weights for TEGroupedLinear are stored in weight0, weight1, etc.
+        if self.weight0 is None:
+            return
+        for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]:
+            _check_unsupported_states(
+                quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0]
+            )
+        if _has_state(self.weight_quantizer, "_amax"):
+            self.weight_quantizer.reset_amax()
+            for i in range(self.num_gemms):
+                weight = getattr(self, f"weight{i}")
+                assert weight is not None, "weight is None"
+
+                max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False)
+        if _has_state(self.input_quantizer, "_pre_quant_scale"):
+            if hasattr(self.input_quantizer, "_pre_quant_scale"):
+                delattr(self.input_quantizer, "_pre_quant_scale")
+            pqs = torch.zeros(
+                (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype
+            )
+            self.input_quantizer.register_buffer("_pre_quant_scale", pqs)
+
+        if _has_state(self.input_quantizer, "_amax"):
+            self.input_quantizer.reset_amax()
+            dummy_input = torch.ones(
+                (1, 1, self.weight0.shape[1]),
+                device=self.weight0.device,
+                dtype=self.original_weight_dtype,
+            )
+            max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False)
+        if _has_state(self.output_quantizer, "_amax"):
+            self.output_quantizer.reset_amax()
+            dummy_input = torch.ones(
+                (1, 1, self.weight0.shape[0]),
+                device=self.weight0.device,
+                dtype=self.original_weight_dtype,
+            )
+            max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
+        # If there are any other states, let's move them to the correct device
+
+        self.weight = None
+        super().modelopt_post_restore(prefix=prefix)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms) to
+        # sharded_state_dict, which is the same as _extra_state. The _extra_state{gemm_idx} entries are
+        # used for the TE FP8 checkpoint, so we need to remove them for modelopt checkpoint restore.
+        filtered_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+        }
+        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+        if v.ndim == 4:
+            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
+        else:
+            quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+)
+class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear):
+    _is_column_parallel = True
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+)
+class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
+    _is_row_parallel = True
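The registrations above plug TEGroupedLinear and its column/row parallel variants into the normal ModelOpt quantization flow. A hedged end-to-end sketch, assuming the standard `mtq.quantize` entry point and an FP8 config; `megatron_moe_model` and `calib_loader` are placeholders.

```python
import modelopt.torch.quantization as mtq

def forward_loop(model):
    for batch in calib_loader:  # hypothetical calibration batches
        model(batch)

# QuantModuleRegistry maps GroupedLinear / TEColumnParallelGroupedLinear /
# TERowParallelGroupedLinear to the _QuantTEGrouped* classes defined above, so each expert's
# grouped GEMM gets input/weight quantizers and is calibrated like any other parallel linear.
quantized_model = mtq.quantize(megatron_moe_model, mtq.FP8_DEFAULT_CFG, forward_loop)
```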

modelopt/torch/utils/distributed.py

Lines changed: 13 additions & 1 deletion
@@ -241,16 +241,28 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
+        self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
+        self.expert_tensor_parallel_group = None
+        if expert_tensor_parallel_group is not None:
+            self.expert_tensor_parallel_group = DistributedProcessGroup(
+                expert_tensor_parallel_group
+            )
 
     def __repr__(self) -> str:
-        return (
+        parallel_groups = (
             f"data_parallel_group: {self.data_parallel_group}, "
             f"tensor_parallel_group: {self.tensor_parallel_group}, "
+            f"expert_model_parallel_group: {self.expert_model_parallel_group}"
         )
+        if self.expert_tensor_parallel_group:
+            parallel_groups += f", expert_tensor_parallel_group: {self.expert_tensor_parallel_group}"
+        return parallel_groups
 
 
 def get_group(ranks: list[int]):
