|
22 | 22 | import megatron.core.parallel_state as mcore_parallel |
23 | 23 | import megatron.core.tensor_parallel.layers as megatron_parallel |
24 | 24 | import megatron.core.transformer.mlp as megatron_mlp |
| 25 | +import megatron.core.transformer.moe.experts as megatron_moe |
25 | 26 | import torch |
26 | 27 | from megatron.core.parallel_state import get_data_parallel_group |
27 | 28 | from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region |
|
40 | 41 | from ..qtensor import QTensorWrapper |
41 | 42 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear |
42 | 43 |
|
| 44 | +try: |
| 45 | + from megatron.core.extensions.transformer_engine import ( |
| 46 | + TEColumnParallelGroupedLinear, |
| 47 | + TERowParallelGroupedLinear, |
| 48 | + ) |
| 49 | + |
| 50 | + from .transformer_engine import _QuantTEGroupedLinear |
| 51 | + |
| 52 | + HAS_TE = True |
| 53 | +except ImportError: |
| 54 | + HAS_TE = False |
| 55 | + |
43 | 56 | logger = logging.getLogger(__name__) |
44 | 57 |
|
45 | 58 | __all__ = [] |
@@ -221,16 +234,19 @@ class _MegatronParallelLinear(_ParallelLinear): |
221 | 234 | ] |
222 | 235 |
|
223 | 236 | def _setup(self): |
224 | | - data_parallel_group = None |
225 | | - try: |
226 | | - data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
227 | | - except AssertionError: |
228 | | - logger.warning("Context parallel group is not initialized, using data parallel group") |
229 | | - data_parallel_group = get_data_parallel_group() |
230 | | - self.parallel_state = ParallelState( |
231 | | - data_parallel_group, |
232 | | - mcore_parallel.get_tensor_model_parallel_group(), |
233 | | - ) |
| 237 | + if not hasattr(self, "parallel_state") or self.parallel_state is None: |
| 238 | + data_parallel_group = None |
| 239 | + try: |
| 240 | + data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
| 241 | + except AssertionError: |
| 242 | + logger.warning( |
| 243 | + "Context parallel group is not initialized, using data parallel group" |
| 244 | + ) |
| 245 | + data_parallel_group = get_data_parallel_group() |
| 246 | + self.parallel_state = ParallelState( |
| 247 | + data_parallel_group, |
| 248 | + mcore_parallel.get_tensor_model_parallel_group(), |
| 249 | + ) |
234 | 250 | super()._setup() |
235 | 251 |
|
236 | 252 | def _process_quantizer_amax(self, k, v, quantizer_state_dict): |
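
The `_setup` change above only builds a default `ParallelState` when one has not been assigned already, which lets a parent module (for example a MoE wrapper) push expert-parallel groups into its child linears before their own setup runs. Below is a minimal, dependency-free sketch of that guard; the classes are illustrative stand-ins, not the real ModelOpt/Megatron types.

```python
# Illustrative stand-ins only: a minimal sketch of the "keep a pre-assigned
# parallel_state" guard, without the real ModelOpt/Megatron classes.


class FakeParallelState:
    def __init__(self, label: str):
        self.label = label


class FakeParallelLinear:
    def _setup(self):
        # Guarded default: only build a ParallelState if none was assigned yet.
        if not hasattr(self, "parallel_state") or self.parallel_state is None:
            self.parallel_state = FakeParallelState("default dp/tp groups")


class FakeExpertMLP:
    def __init__(self):
        self.linear_fc1 = FakeParallelLinear()
        self.linear_fc2 = FakeParallelLinear()

    def _setup(self):
        # The parent assigns expert-parallel groups to its submodules first ...
        expert_state = FakeParallelState("expert dp/tp groups")
        self.linear_fc1.parallel_state = expert_state
        self.linear_fc2.parallel_state = expert_state
        # ... so the children's guarded _setup keeps them instead of overwriting.
        self.linear_fc1._setup()
        self.linear_fc2._setup()


mlp = FakeExpertMLP()
mlp._setup()
assert mlp.linear_fc1.parallel_state.label == "expert dp/tp groups"
```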
@@ -472,3 +488,95 @@ class _RealQuantMegatronRowParallelLinear( |
472 | 488 |
|
473 | 489 | def forward(self, input, *args, **kwargs): |
474 | 490 | return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) |
| 491 | + |
| 492 | + |
| 493 | +@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"}) |
| 494 | +class _MegatronSequentialMLP(_MegatronMLP): |
| 495 | + def _setup(self): |
| 496 | + if not hasattr(self, "parallel_state") or self.parallel_state is None: |
| 497 | + self.parallel_state = ParallelState( |
| 498 | + mcore_parallel.get_expert_data_parallel_group(), |
| 499 | + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), |
| 500 | + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), |
| 501 | + ) |
| 502 | + |
| 503 | + # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2 |
| 504 | + for expert in self.local_experts: |
| 505 | + expert.linear_fc1.parallel_state = self.parallel_state |
| 506 | + expert.linear_fc2.parallel_state = self.parallel_state |
| 507 | + |
| 508 | + def sync_moe_local_experts_amax(self): |
| 509 | + """Sync amax across local experts in a SequentialMLP. |
| 510 | +
|
| 511 | + Amax across EP and ETP (for RowParallel) is synchronized as part of model_calib.max_calibrate(). |
| 512 | + This function is then called to synchronize the amax values across local experts so that all |
| 513 | + local experts share the same amax. |
| 514 | + """ |
| 515 | + torch.distributed.barrier() |
| 516 | + # Collect amax from all local experts |
| 517 | + amax_dict = {} |
| 518 | + for expert in self.local_experts: |
| 519 | + for name, module in expert.named_modules(): |
| 520 | + if isinstance(module, TensorQuantizer) and module.amax is not None: |
| 521 | + stored_amax = amax_dict.get(name) |
| 522 | + amax_tensor = module.amax.detach().clone() |
| 523 | + amax_dict[name] = ( |
| 524 | + amax_tensor |
| 525 | + if stored_amax is None |
| 526 | + else torch.maximum(stored_amax, amax_tensor) |
| 527 | + ) |
| 528 | + |
| 529 | + # Apply synchronized amax values back to all local experts |
| 530 | + for expert in self.local_experts: |
| 531 | + for name, module in expert.named_modules(): |
| 532 | + if isinstance(module, TensorQuantizer) and module.amax is not None: |
| 533 | + module.amax = amax_dict[name].detach().clone().to(module.amax.device) |
| 534 | + |
| 535 | + |
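
One way `sync_moe_local_experts_amax` could be invoked after calibration is sketched below. This is a hypothetical helper, not the library's own wiring: it simply walks the model and lets every module that exposes the method reconcile amax across its local experts.

```python
# Hypothetical post-calibration helper (not ModelOpt's own wiring): after amax
# calibration, let every SequentialMLP-style module reconcile amax across its
# local experts so they all quantize with a shared scale.
import torch.nn as nn


def sync_all_local_expert_amax(model: nn.Module) -> None:
    for module in model.modules():
        sync_fn = getattr(module, "sync_moe_local_experts_amax", None)
        if callable(sync_fn):
            sync_fn()  # max-reduces each quantizer's amax over local experts
```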
| 536 | +if HAS_TE: |
| 537 | + # Quantized subclasses to support TEGroupedMLP quantization |
| 538 | + class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear): |
| 539 | + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
| 540 | + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in range(1, num_gemms) |
| 541 | + # to the sharded state dict; each entry duplicates _extra_state and is only needed for TE FP8 |
| 542 | + # checkpoints. Remove the duplicated _extra_state{gemm_idx} entries so a ModelOpt checkpoint |
| 543 | + # can be restored. |
| 544 | + filtered_state_dict = { |
| 545 | + k: v |
| 546 | + for k, v in state_dict.items() |
| 547 | + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) |
| 548 | + } |
| 549 | + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) |
| 550 | + |
| 551 | + def _process_quantizer_amax(self, k, v, quantizer_state_dict): |
| 552 | + assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" |
| 553 | + quantizer_state_dict[k] = v.view(-1) |
| 554 | + |
| 555 | + @QuantModuleRegistry.register( |
| 556 | + {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} |
| 557 | + ) |
| 558 | + class _MegatronTEGroupedColumnParallelLinear( |
| 559 | + _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear |
| 560 | + ): |
| 561 | + pass |
| 562 | + |
| 563 | + @QuantModuleRegistry.register( |
| 564 | + {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} |
| 565 | + ) |
| 566 | + class _MegatronTEGroupedRowParallelLinear( |
| 567 | + _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear |
| 568 | + ): |
| 569 | + pass |
| 570 | + |
| 571 | + @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) |
| 572 | + class _MegatronTEGroupedMLP(_MegatronMLP): |
| 573 | + def _setup(self): |
| 574 | + if not hasattr(self, "parallel_state") or self.parallel_state is None: |
| 575 | + self.parallel_state = ParallelState( |
| 576 | + mcore_parallel.get_expert_data_parallel_group(), |
| 577 | + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), |
| 578 | + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), |
| 579 | + ) |
| 580 | + # Initialize parallel state for submodules linear_fc1 and linear_fc2 |
| 581 | + self.linear_fc1.parallel_state = self.parallel_state |
| 582 | + self.linear_fc2.parallel_state = self.parallel_state |
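
To make the `_load_from_state_dict` filtering in `_QuantMegatronTEGroupedLinear` concrete, here is a standalone illustration with hypothetical key names: only the numbered `_extra_state{gemm_idx}` duplicates are dropped, while the base `_extra_state` and the weights survive.

```python
# Standalone illustration of the _extra_state{gemm_idx} filtering above.
# Key names are hypothetical; only the numbered duplicates are dropped.
num_gemms = 4
state_dict = {
    "linear_fc1.weight0": "w0",
    "linear_fc1._extra_state": "fp8-meta",
    "linear_fc1._extra_state1": "fp8-meta",
    "linear_fc1._extra_state2": "fp8-meta",
    "linear_fc1._extra_state3": "fp8-meta",
}
filtered = {
    k: v
    for k, v in state_dict.items()
    if not any(k.endswith(f"_extra_state{num}") for num in range(1, num_gemms))
}
assert sorted(filtered) == ["linear_fc1._extra_state", "linear_fc1.weight0"]
```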