
Commit 6ef9954

Added support for quantizing TEGroupedMLP for megatron-lm (#403)
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 8c6b915 commit 6ef9954

File tree

8 files changed: +661 −30 lines

modelopt/torch/quantization/model_calib.py

Lines changed: 10 additions & 4 deletions

@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel group."""
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
 

@@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp(
         # Syncing amax across TP for sequential quantizer
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
+                # Syncing amax across TP for sequential quantizer
                 sync_quantizer_amax_across_tp(
                     _q, linear_name, quantizer_type, axes_for_sync, parallel_state
                 )

@@ -174,6 +176,10 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
 
 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 118 additions & 10 deletions

@@ -22,6 +22,7 @@
 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
+import megatron.core.transformer.moe.experts as megatron_moe
 import torch
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region

@@ -40,6 +41,18 @@
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
+try:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelGroupedLinear,
+        TERowParallelGroupedLinear,
+    )
+
+    from .transformer_engine import _QuantTEGroupedLinear
+
+    HAS_TE = True
+except ImportError:
+    HAS_TE = False
+
 logger = logging.getLogger(__name__)
 
 __all__ = []

@@ -221,16 +234,19 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]
 
     def _setup(self):
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            logger.warning("Context parallel group is not initialized, using data parallel group")
-            data_parallel_group = get_data_parallel_group()
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-        )
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                logger.warning(
+                    "Context parallel group is not initialized, using data parallel group"
+                )
+                data_parallel_group = get_data_parallel_group()
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                mcore_parallel.get_tensor_model_parallel_group(),
+            )
         super()._setup()
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):

@@ -472,3 +488,95 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
+class _MegatronSequentialMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            self.parallel_state = ParallelState(
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+
+        # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
+        for expert in self.local_experts:
+            expert.linear_fc1.parallel_state = self.parallel_state
+            expert.linear_fc2.parallel_state = self.parallel_state
+
+    def sync_moe_local_experts_amax(self):
+        """Sync amax across local experts in a SequentialMLP.
+
+        amax across EP and ETP (for RowParallel) is synchronized as part of model_calib.max_calibrate().
+        This function is called to synchronize the amax values across local experts so that all local
+        experts share the same amax.
+        """
+        torch.distributed.barrier()
+        # Collect amax from all local experts
+        amax_dict = {}
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    stored_amax = amax_dict.get(name)
+                    amax_tensor = module.amax.detach().clone()
+                    amax_dict[name] = (
+                        amax_tensor
+                        if stored_amax is None
+                        else torch.maximum(stored_amax, amax_tensor)
+                    )
+
+        # Apply synchronized amax values back to all local experts
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+
+
+if HAS_TE:
+    # Quantized subclasses to support TEGroupedMLP quantization
+    class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
+        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+            # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms) to
+            # the sharded_state_dict, with the same content as _extra_state. Those entries are only used
+            # for TE FP8 checkpoints, so drop them before a modelopt checkpoint restore.
+            filtered_state_dict = {
+                k: v
+                for k, v in state_dict.items()
+                if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+            }
+            return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+        def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+            assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+            quantizer_state_dict[k] = v.view(-1)
+
+    @QuantModuleRegistry.register(
+        {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedColumnParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register(
+        {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedRowParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
+    class _MegatronTEGroupedMLP(_MegatronMLP):
+        def _setup(self):
+            if not hasattr(self, "parallel_state") or self.parallel_state is None:
+                self.parallel_state = ParallelState(
+                    mcore_parallel.get_expert_data_parallel_group(),
+                    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                )
+            # Initialize parallel state for submodules linear_fc1 and linear_fc2
+            self.linear_fc1.parallel_state = self.parallel_state
+            self.linear_fc2.parallel_state = self.parallel_state
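
sync_moe_local_experts_amax above takes, for each quantizer name, the element-wise maximum amax over all local experts and writes it back, so every expert ends up with identical quantization scales. A standalone sketch of that reduction on plain dictionaries (illustrative only; "experts" stands in for SequentialMLP.local_experts and no Megatron modules are involved):

# Illustrative sketch, not the commit's code: per-name MAX reduction across local experts.
import torch


def sync_local_expert_amax(experts: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
    """Make every expert share the element-wise maximum amax for each quantizer name."""
    amax_dict: dict[str, torch.Tensor] = {}
    for expert in experts:
        for name, amax in expert.items():
            stored = amax_dict.get(name)
            amax_dict[name] = amax.clone() if stored is None else torch.maximum(stored, amax)
    return [{name: amax_dict[name].clone() for name in expert} for expert in experts]


# Two local experts with different calibrated amax values for the same quantizer:
experts = [
    {"linear_fc1.weight_quantizer": torch.tensor(1.5)},
    {"linear_fc1.weight_quantizer": torch.tensor(2.0)},
]
print(sync_local_expert_amax(experts))  # both experts now report amax == 2.0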

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 58 additions & 0 deletions

@@ -17,6 +17,7 @@
 
 import torch
 import transformer_engine as te
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
 import transformer_engine.pytorch.module.linear as te_linear
 
 from ..nn import QuantModuleRegistry

@@ -58,3 +59,60 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
 
     # Override the quantized linear function
     _quantized_linear_fn = te_quantized_linear_fn
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
+class _QuantTEGroupedLinear(_ParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
+    def _setup(self):
+        # GroupedMLP stores the weights as weight0, weight1, etc. Setup uses self.weight to
+        # extract shape, dtype, etc. when initializing the quantizer states, so temporarily
+        # assign self.weight0 to self.weight to run that initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        super()._setup()
+        # Remove self.weight after setup.
+        delattr(self, "weight")
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        # GroupedMLP stores the weights as weight0, weight1, etc. Post-restore uses self.weight to
+        # extract shape, dtype, etc. when initializing the quantizer states, so temporarily
+        # assign self.weight0 to self.weight to run that initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        super().modelopt_post_restore(prefix=prefix)
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
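
te_grouped_quantized_linear_fn has to pull the input and the per-GEMM weights out of a flat argument tuple: args[idx] is the input (idx is 1 for "_forward", which carries a ctx object first), args[idx + 1] is what the commit treats as the m_splits list (its length gives num_gemms), and the last 2 * num_gemms entries are the weights followed by the biases. A small self-contained sketch of that slicing; the example tuple layout is a simplified assumption, not TE's exact signature:

# Illustrative sketch of the argument slicing used by te_grouped_quantized_linear_fn.
def split_grouped_linear_args(args: tuple, func_name: str = "forward"):
    idx = 1 if func_name == "_forward" else 0  # "_forward" receives a ctx object first
    inp = args[idx]
    num_gemms = len(args[idx + 1])             # one m_splits entry per GEMM
    weights_and_biases = args[-2 * num_gemms:]
    weights = weights_and_biases[:num_gemms]
    biases = weights_and_biases[num_gemms:]
    passthrough = args[idx + 1 : -2 * num_gemms]  # m_splits plus options, forwarded unchanged
    return inp, passthrough, weights, biases


# Example with num_gemms == 2 and two pass-through options:
inp, passthrough, weights, biases = split_grouped_linear_args(
    ("x", [16, 16], "opt_a", "opt_b", "w0", "w1", "b0", "b1")
)
assert inp == "x" and weights == ("w0", "w1") and biases == ("b0", "b1")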

modelopt/torch/quantization/utils.py

Lines changed: 5 additions & 2 deletions

@@ -251,8 +251,11 @@ def is_quantized_linear(module):
         isinstance(module, QuantModule)
         and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
         and hasattr(module, "weight_quantizer")
-        and getattr(module, "weight", None) is not None
-        and module.weight.dim() == 2
+        and (
+            (getattr(module, "weight", None) is not None and module.weight.dim() == 2)
+            # module.weight0 check is required to support TEGroupedLinear
+            or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2)
+        )
     )
 
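
The change to is_quantized_linear accepts modules that expose weight0 instead of weight, which is how TEGroupedLinear stores its per-GEMM weights. A minimal sketch of the broadened shape check, using a dummy module rather than a real QuantModule:

# Illustrative sketch: the 2-D weight check now also looks at weight0 (grouped layout).
import torch


def has_2d_weight(module) -> bool:
    return (getattr(module, "weight", None) is not None and module.weight.dim() == 2) or (
        getattr(module, "weight0", None) is not None and module.weight0.dim() == 2
    )


class FakeGroupedLinear:
    # Grouped layout: no self.weight, but per-GEMM weights weight0, weight1, ...
    weight0 = torch.zeros(8, 4)


print(has_2d_weight(FakeGroupedLinear()))  # True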

modelopt/torch/utils/distributed.py

Lines changed: 5 additions & 1 deletion

@@ -241,16 +241,20 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
+        self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
 
     def __repr__(self) -> str:
-        return (
+        parallel_groups = (
             f"data_parallel_group: {self.data_parallel_group}, "
             f"tensor_parallel_group: {self.tensor_parallel_group}, "
+            f"expert_model_parallel_group: {self.expert_model_parallel_group}"
        )
+        return parallel_groups
 
 
 def get_group(ranks: list[int]):
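
With the extra constructor argument, the MoE plugins above can hand ParallelState an expert-model-parallel group alongside the data- and tensor-parallel groups, and __repr__ now reports it. A usage sketch, assuming Megatron's parallel state has already been initialized (the import path mirrors the file location in this commit):

# Usage sketch: constructing a ParallelState with the new expert_model_parallel_group
# argument, as the Megatron MoE plugins do. Requires an initialized Megatron parallel state.
import megatron.core.parallel_state as mcore_parallel
from modelopt.torch.utils.distributed import ParallelState

parallel_state = ParallelState(
    mcore_parallel.get_expert_data_parallel_group(),
    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
)
print(parallel_state)  # repr now also includes the expert_model_parallel_group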
