
Commit 95f6c25

fix moe token forward
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 0b25bcb commit 95f6c25

File tree: 2 files changed, +49 -8 lines changed

modelopt/torch/quantization/plugins/huggingface.py (16 additions, 7 deletions)
@@ -345,7 +345,16 @@ def backward(ctx, grad_output):
 _transposed_quantize = _TransposedQuantization.apply
 
 
-class _QuantMoeSparseMoe(QuantModule):
+class _QuantSparseMoe(QuantModule):
+    """Module to support special handling of token dispatching during calibration.
+
+    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
+    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
+    returns.
+
+    If calibration is not enabled, this module behaves as a normal MoELayer.
+    """
+
     def _setup(self):
         pass
 
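The forward that implements this docstring is not part of this hunk. Below is a sketch of the pattern it describes, assuming the wrapped block exposes top_k and num_experts attributes (the _QuantDbrxFFN overrides in the next hunk map DBRX's differently named fields onto exactly these) and using the same _if_calib check as the Megatron patch later in this commit. This is an illustration, not the actual implementation:

class _SketchSparseMoe(QuantModule):  # hypothetical name, illustration only
    def forward(self, hidden_states):
        # If any child quantizer is calibrating, run one pass with every token
        # routed to every expert so all experts collect statistics, then restore
        # the real top_k before computing the output that is actually returned.
        if any(getattr(m, "_if_calib", False) for m in self.modules()):
            original_top_k = self.top_k
            self.top_k = self.num_experts
            super().forward(hidden_states)
            self.top_k = original_top_k
        return super().forward(hidden_states)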

@@ -480,7 +489,7 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         return self.w2_linear[expert_idx](x1)
 
 
-class _QuantDbrxFFN(_QuantMoeSparseMoe):
+class _QuantDbrxFFN(_QuantSparseMoe):
     @property
     def num_experts(self):
         return self.router.moe_num_experts
@@ -498,7 +507,7 @@ def top_k(self, value):
     from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
 
     if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantMoeSparseMoe)
+        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
 
     if Llama4TextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -526,7 +535,7 @@ def top_k(self, value):
 
     if MixtralSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -544,7 +553,7 @@ def top_k(self, value):
 
     if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -554,7 +563,7 @@ def top_k(self, value):
 
     if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -564,7 +573,7 @@ def top_k(self, value):
 
     if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
        )
 except ImportError:
     pass

modelopt/torch/quantization/plugins/megatron.py (33 additions, 1 deletion)
@@ -23,6 +23,7 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
+import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
 import torch
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -36,7 +37,7 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
+from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -247,6 +248,14 @@ def _setup(self):
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
         )
+
+        if getattr(self, "gradient_accumulation_fusion", False):
+            warnings.warn(
+                "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
+                "Setting gradient_accumulation_fusion to False."
+            )
+            self.gradient_accumulation_fusion = False
+
         super()._setup()
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -580,3 +589,26 @@ def _setup(self):
         # initialize parallel state for submodules linear_fc1 and linear_fc2
         self.linear_fc1.parallel_state = self.parallel_state
         self.linear_fc2.parallel_state = self.parallel_state
+
+
+@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
+class _QuantMoELayer(QuantModule):
+    """Module to support special handling of token dispatching during calibration.
+
+    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
+    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
+    returns.
+
+    If calibration is not enabled, this module behaves as a normal MoELayer.
+    """
+
+    def _setup(self):
+        pass
+
+    def forward(self, hidden_states):
+        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
+            original_top_k = self.router.topk
+            self.router.topk = self.router.num_experts
+            super().forward(hidden_states)
+            self.router.topk = original_top_k
+        return super().forward(hidden_states)
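With MoELayer registered, the calibration path is exercised through ModelOpt's usual quantize flow. A minimal sketch, assuming the mtq.quantize(model, config, forward_loop) API; the model, calib_dataloader, and config choice are placeholders, not part of this commit:

import modelopt.torch.quantization as mtq

# Placeholders: substitute your own Megatron model and calibration data.
model = ...
calib_dataloader = ...

def forward_loop(m):
    # While quantizers are calibrating (_if_calib is set), _QuantMoELayer.forward
    # temporarily routes every token to every expert so all experts collect amax.
    for batch in calib_dataloader:
        m(batch)

# Sketch only: mtq.quantize swaps registered modules for their quant counterparts,
# runs forward_loop to collect calibration statistics, and returns the model.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

Note the design choice in the patched forward: the first super().forward runs with topk raised to num_experts purely so every expert's quantizers see tokens, and its output is discarded; the value actually returned comes from the second pass with the original top-k routing, so calibration does not change what downstream layers observe.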
