
Commit 963657d

Kinjal Patel committed
Added quantization support for TEGroupedMoE for megatron-lm
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 95da832 commit 963657d
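
As a quick orientation for reviewers, the sketch below shows how this support is typically exercised end to end: a Megatron-LM MoE model whose experts use TE grouped GEMMs is calibrated and quantized through the usual modelopt entry point. This is a hedged usage sketch, not part of the commit; build_megatron_moe_model, calib_dataloader, and the FP8 config choice are placeholders/assumptions.

import modelopt.torch.quantization as mtq

# Hypothetical user-side helpers, not provided by this commit.
model = build_megatron_moe_model()  # MoE model whose experts use TEGroupedLinear

def calib_forward_loop(model):
    for batch in calib_dataloader:  # hypothetical calibration data
        model(batch)

# Any supported quantization config could be used; FP8 is shown only as an example.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calib_forward_loop)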

File tree

6 files changed: 715 additions & 16 deletions


modelopt/torch/quantization/model_calib.py

Lines changed: 8 additions & 3 deletions
@@ -80,22 +80,27 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp_ep(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp_ep(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
             quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
+            if parallel_state.expert_tensor_parallel_group is not None:
+                quantizer.sync_amax_across_distributed_group(
+                    parallel_state.expert_tensor_parallel_group
+                )
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_cp_ep(child, module.parallel_state)
             # TP sync:
             # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
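
The expert-parallel sync added above follows the same pattern as the existing DP/CP sync: every rank holding a replica of the same quantizer reduces its amax with MAX over the group, so all replicas end up with one shared calibration range. Below is a minimal sketch of that reduction, assuming an initialized torch.distributed job; sketch_sync_amax is illustrative and not the modelopt helper.

import torch
import torch.distributed as dist

def sketch_sync_amax(amax: torch.Tensor, group) -> torch.Tensor:
    # After the all-reduce, every rank in `group` holds the element-wise maximum,
    # so expert-parallel replicas of a layer agree on the calibration range.
    if group is not None and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax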

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 170 additions & 2 deletions
@@ -22,6 +22,8 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
+from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
@@ -34,8 +36,8 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
-from ..nn.modules.quant_linear import RealQuantLinear
+from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
+from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
@@ -223,10 +225,18 @@ def _setup(self):
             data_parallel_group = get_data_parallel_group(with_context_parallel=True)
         except AssertionError:
             data_parallel_group = get_data_parallel_group()
+
+        try:
+            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
+        except AssertionError:
+            expert_tensor_parallel_group = None
+
         self.parallel_state = ParallelState(
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
             mcore_parallel.get_context_parallel_group(),
+            mcore_parallel.get_expert_model_parallel_group(),
+            expert_tensor_parallel_group,
         )
         super()._setup()
 
@@ -469,3 +479,161 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"})
+class _QuantTEGroupedLinear(_MegatronParallelLinear):
+    def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+
+        try:
+            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
+        except AssertionError:
+            expert_tensor_parallel_group = None
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_context_parallel_group(),
+            mcore_parallel.get_expert_model_parallel_group(),
+            expert_tensor_parallel_group,
+        )
+        self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
+        self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
+        self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
+        self.output_quantizer.disable()
+
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype
+
+    @property
+    def functionals_to_replace(self):
+        original_forward = te_grouped_linear._GroupedLinear.forward
+
+        def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
+            num_gemms = len(m_splits)
+            weights_and_biases = args[-2 * num_gemms :]
+            weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+            quantized_inputs = self.input_quantizer(inp)
+            quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+            output = original_forward(
+                ctx,
+                quantized_inputs,
+                m_splits,
+                *args[: -2 * num_gemms],
+                *quantized_weights,
+                *biases,
+            )
+            return self.output_quantizer(output)
+
+        return [
+            (
+                te_grouped_linear._GroupedLinear,
+                "forward",
+                te_grouped_quantized_linear_fn,
+            ),
+        ]
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks.
+
+        ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their
+        shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism
+        could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer
+        states also changes. So we need to re-calculate the quantizer states.
+        """
+        from modelopt.torch.quantization.model_calib import max_calibrate
+
+        def _check_unsupported_states(quantizer: TensorQuantizer):
+            for k in quantizer.state_dict():
+                if k not in ["_amax", "_pre_quant_scale"]:
+                    warnings.warn(
+                        f"Restore of {k} for {prefix} is not supported. The restore of this layer might be "
+                        f"incorrect. Please implement a custom restore for {k}."
+                    )
+
+        def _has_state(quantizer, name):
+            # Handling for SequentialQuantizer
+            quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer
+            return hasattr(quantizer, name)
+
+        # weights for TEGroupedLinear are stored in weight0, weight1, etc.
+        if self.weight0 is None:
+            return
+        for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]:
+            _check_unsupported_states(
+                quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0]
+            )
+        if _has_state(self.weight_quantizer, "_amax"):
+            self.weight_quantizer.reset_amax()
+            for i in range(self.num_gemms):
+                weight = getattr(self, f"weight{i}")
+                assert weight is not None, "weight is None"
+
+                max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False)
+        if _has_state(self.input_quantizer, "_pre_quant_scale"):
+            if hasattr(self.input_quantizer, "_pre_quant_scale"):
+                delattr(self.input_quantizer, "_pre_quant_scale")
+            pqs = torch.zeros(
+                (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype
+            )
+            self.input_quantizer.register_buffer("_pre_quant_scale", pqs)
+
+        if _has_state(self.input_quantizer, "_amax"):
+            self.input_quantizer.reset_amax()
+            dummy_input = torch.ones(
+                (1, 1, self.weight0.shape[1]),
+                device=self.weight0.device,
+                dtype=self.original_weight_dtype,
+            )
+            max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False)
+        if _has_state(self.output_quantizer, "_amax"):
+            self.output_quantizer.reset_amax()
+            dummy_input = torch.ones(
+                (1, 1, self.weight0.shape[0]),
+                device=self.weight0.device,
+                dtype=self.original_weight_dtype,
+            )
+            max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
+        # If there are any other states, lets move them to the correct device
+
+        self.weight = None
+        super().modelopt_post_restore(prefix=prefix)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
+        # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
+        # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
+        # for modelopt checkpoint restore
+        filtered_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+        }
+        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+        if v.ndim == 4:
+            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
+        else:
+            quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+)
+class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear):
+    _is_column_parallel = True
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+)
+class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
+    _is_row_parallel = True
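
The modelopt_post_restore docstring above explains why quantizer states are recomputed after restore instead of being reloaded verbatim: parallelism can change between save and restore, and a grouped linear keeps its experts in weight0 ... weight{num_gemms-1}. The snippet below is a simplified stand-in for that idea (not the modelopt implementation): it derives one per-tensor weight amax as the running maximum over all per-expert weights.

import torch

def recompute_grouped_weight_amax(weights: list[torch.Tensor]) -> torch.Tensor:
    # Running max of |w| across every per-expert weight tensor in the grouped linear.
    amax = torch.zeros((), device=weights[0].device, dtype=torch.float32)
    for w in weights:
        amax = torch.maximum(amax, w.detach().abs().amax().float())
    return amax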

modelopt/torch/utils/distributed.py

Lines changed: 13 additions & 1 deletion
@@ -242,18 +242,30 @@ def __init__(
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
         context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
         self.context_parallel_group = DistributedProcessGroup(context_parallel_group)
+        self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
+        self.expert_tensor_parallel_group = None
+        if expert_tensor_parallel_group is not None:
+            self.expert_tensor_parallel_group = DistributedProcessGroup(
+                expert_tensor_parallel_group
+            )
 
     def __repr__(self) -> str:
-        return (
+        parallel_groups = (
            f"data_parallel_group: {self.data_parallel_group}, "
            f"tensor_parallel_group: {self.tensor_parallel_group}, "
            f"context_parallel_group: {self.context_parallel_group}"
+           f"expert_model_parallel_group: {self.expert_model_parallel_group}"
        )
+       if self.expert_tensor_parallel_group:
+           parallel_groups += f"expert_tensor_parallel_group: {self.expert_tensor_parallel_group}"
+       return parallel_groups
 
 
 def get_group(ranks: list[int]):
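
With the two new optional groups, building a ParallelState for an MoE setup might look like the sketch below. The Megatron getters are the same ones used by the plugin above, and the try/except mirrors its handling of Megatron versions that do not expose a separate expert tensor-parallel group; treat the exact calls as assumptions about your Megatron-LM version.

from megatron.core import parallel_state as mcore_parallel
from modelopt.torch.utils.distributed import ParallelState

try:
    etp_group = mcore_parallel.get_expert_tensor_parallel_group()
except AssertionError:
    etp_group = None

parallel_state = ParallelState(
    mcore_parallel.get_data_parallel_group(with_context_parallel=True),
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
    mcore_parallel.get_expert_model_parallel_group(),
    etp_group,
)
print(parallel_state)  # repr now includes the expert parallel groups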
