|
23 | 23 | import megatron.core.tensor_parallel.layers as megatron_parallel |
24 | 24 | import megatron.core.transformer.mlp as megatron_mlp |
25 | 25 | import megatron.core.transformer.moe.experts as megatron_moe |
| 26 | +import megatron.core.transformer.moe.moe_layer as megatron_moe_layer |
26 | 27 | import torch |
27 | 28 | from megatron.core.parallel_state import get_data_parallel_group |
28 | 29 | from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region |
|
36 | 37 | ) |
37 | 38 | from modelopt.torch.utils.distributed import ParallelState |
38 | 39 |
|
39 | | -from ..nn import QuantModuleRegistry, TensorQuantizer |
| 40 | +from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer |
40 | 41 | from ..nn.modules.quant_linear import RealQuantLinear |
41 | 42 | from ..qtensor import QTensorWrapper |
42 | 43 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear |
@@ -247,6 +248,14 @@ def _setup(self): |
247 | 248 | data_parallel_group, |
248 | 249 | mcore_parallel.get_tensor_model_parallel_group(), |
249 | 250 | ) |
| 251 | + |
| 252 | + if getattr(self, "gradient_accumulation_fusion", False): |
| 253 | + warnings.warn( |
| 254 | + "gradient_accumulation_fusion is not supported with ModelOpt quantization. " |
| 255 | + "Setting gradient_accumulation_fusion to False." |
| 256 | + ) |
| 257 | + self.gradient_accumulation_fusion = False |
| 258 | + |
250 | 259 | super()._setup() |
251 | 260 |
|
252 | 261 | def _process_quantizer_amax(self, k, v, quantizer_state_dict): |
@@ -580,3 +589,26 @@ def _setup(self): |
580 | 589 | # initialize parallel state for submodules linear_fc1 and linear_fc2 |
581 | 590 | self.linear_fc1.parallel_state = self.parallel_state |
582 | 591 | self.linear_fc2.parallel_state = self.parallel_state |
| 592 | + |
| 593 | + |
| 594 | +@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"}) |
| 595 | +class _QuantMoELayer(QuantModule): |
| 596 | + """Module to support special handling of token dispatching during calibration. |
| 597 | +
| 598 | +    During calibration, we forward all tokens to all experts so that every expert sees enough tokens to calibrate.
| 599 | +    However, even in calibration mode, the real top_k routing is used to compute the outputs this instance
| 600 | +    returns.
| 601 | +
| 602 | + If calibration is not enabled, this module behaves as a normal MoELayer. |
| 603 | + """ |
| 604 | + |
| 605 | + def _setup(self): |
| 607 | +        pass  # the MoELayer itself has no quantizers to set up; its submodules handle their own
| 607 | + |
| 608 | + def forward(self, hidden_states): |
| 609 | + if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): |
| 610 | + original_top_k = self.router.topk |
| 611 | +            self.router.topk = self.router.num_experts  # temporarily route every token to every expert
| 612 | +            super().forward(hidden_states)  # calibration-only pass; the output is discarded
| 613 | +            self.router.topk = original_top_k
| 614 | +        return super().forward(hidden_states)  # actual output uses the original top_k routing
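
To make the calibration-time routing change easier to follow outside of Megatron, here is a minimal, self-contained sketch of the same pattern. The `ToyRouter`, `ToyExpert`, `ToyMoELayer`, and `CalibAwareMoELayer` classes below are hypothetical stand-ins, not part of megatron.core or modelopt; only `CalibAwareMoELayer.forward` mirrors the `_QuantMoELayer.forward` added in this diff: when any expert module carries an `_if_calib` flag, the router's top-k is temporarily widened to the full expert count so every expert sees tokens, that pass is discarded, and the returned output comes from a second forward with the original top-k.

```python
import torch
import torch.nn as nn


class ToyRouter(nn.Module):
    """Toy stand-in for a top-k router: scores tokens and keeps the top-k experts."""

    def __init__(self, hidden_size: int, num_experts: int, topk: int):
        super().__init__()
        self.num_experts = num_experts
        self.topk = topk
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        logits = self.gate(hidden_states)                       # [tokens, num_experts]
        probs, indices = torch.topk(logits.softmax(dim=-1), self.topk, dim=-1)
        return probs, indices                                   # each [tokens, topk]


class ToyExpert(nn.Module):
    """Toy expert MLP with a calibration flag mimicking the `_if_calib` attribute checked in the diff."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.fc = nn.Linear(hidden_size, hidden_size)
        self._if_calib = False  # set True while calibration statistics are being collected

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(hidden_states)


class ToyMoELayer(nn.Module):
    """Minimal MoE layer: route each token to its top-k experts and sum the weighted outputs."""

    def __init__(self, hidden_size: int = 16, num_experts: int = 4, topk: int = 2):
        super().__init__()
        self.router = ToyRouter(hidden_size, num_experts, topk)
        self.experts = nn.ModuleList(ToyExpert(hidden_size) for _ in range(num_experts))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        probs, indices = self.router(hidden_states)
        out = torch.zeros_like(hidden_states)
        for k in range(indices.shape[-1]):
            for e, expert in enumerate(self.experts):
                mask = indices[:, k] == e
                if mask.any():
                    out[mask] += probs[mask, k, None] * expert(hidden_states[mask])
        return out


class CalibAwareMoELayer(ToyMoELayer):
    """Mirrors the `_QuantMoELayer.forward` pattern: during calibration, run one extra forward
    with top-k widened to all experts so every expert sees tokens, then restore top-k and
    return the output of the normally routed forward."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
            original_top_k = self.router.topk
            self.router.topk = self.router.num_experts
            super().forward(hidden_states)           # calibration-only pass; output discarded
            self.router.topk = original_top_k
        return super().forward(hidden_states)        # real top-k routing decides the output


if __name__ == "__main__":
    layer = CalibAwareMoELayer()
    for expert in layer.experts:
        expert._if_calib = True                      # pretend a calibrator is attached
    tokens = torch.randn(8, 16)
    print(layer(tokens).shape)                       # torch.Size([8, 16])
```

The extra all-expert pass makes forward steps more expensive, but it only runs while a calibrator is active on the experts; once calibration is disabled, the layer behaves exactly like the base MoE layer.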