|
23 | 23 | import megatron.core.tensor_parallel.layers as megatron_parallel |
24 | 24 | import megatron.core.transformer.mlp as megatron_mlp |
25 | 25 | import torch |
| 26 | +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear |
| 27 | +from megatron.core.extensions import transformer_engine as megatron_te |
26 | 28 | from megatron.core.parallel_state import get_data_parallel_group |
27 | 29 | from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region |
28 | 30 | from megatron.core.transformer import MegatronModule |
|
35 | 37 | ) |
36 | 38 | from modelopt.torch.utils.distributed import ParallelState |
37 | 39 |
|
38 | | -from ..nn import QuantModuleRegistry, TensorQuantizer |
39 | | -from ..nn.modules.quant_linear import RealQuantLinear |
| 40 | +from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer |
| 41 | +from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear |
40 | 42 | from ..qtensor import QTensorWrapper |
41 | 43 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear |
42 | 44 |
|
@@ -227,9 +229,17 @@ def _setup(self): |
227 | 229 | except AssertionError: |
228 | 230 | logger.warning("Context parallel group is not initialized, using data parallel group") |
229 | 231 | data_parallel_group = get_data_parallel_group() |
| 232 | + |
| 233 | + try: |
| 234 | + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() |
| 235 | + except AssertionError: |
| 236 | + expert_tensor_parallel_group = None |
| 237 | + |
230 | 238 | self.parallel_state = ParallelState( |
231 | 239 | data_parallel_group, |
232 | 240 | mcore_parallel.get_tensor_model_parallel_group(), |
| 241 | + mcore_parallel.get_expert_model_parallel_group(), |
| 242 | + expert_tensor_parallel_group, |
233 | 243 | ) |
234 | 244 | super()._setup() |
235 | 245 |
|
@@ -472,3 +482,161 @@ class _RealQuantMegatronRowParallelLinear( |
472 | 482 |
|
473 | 483 | def forward(self, input, *args, **kwargs): |
474 | 484 | return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) |
| 485 | + |
| 486 | + |
| 487 | +# Register the public te.pytorch.GroupedLinear class |
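| | +# GroupedLinear fuses num_gemms GEMMs (e.g. one per MoE expert) into a single module and stores
| | +# the per-GEMM weights as weight0..weight{num_gemms-1}; the quantizers below are shared by all GEMMs.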
| 488 | +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) |
| 489 | +class _QuantTEGroupedLinear(_MegatronParallelLinear): |
| 490 | + def _setup(self): |
| 491 | + data_parallel_group = None |
| 492 | + try: |
| 493 | + data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
| 494 | + except AssertionError: |
| 495 | + data_parallel_group = get_data_parallel_group() |
| 496 | + |
| 497 | + try: |
| 498 | + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() |
| 499 | + except AssertionError: |
| 500 | + expert_tensor_parallel_group = None |
| 501 | + self.parallel_state = ParallelState( |
| 502 | + data_parallel_group, |
| 503 | + mcore_parallel.get_tensor_model_parallel_group(), |
| 505 | + mcore_parallel.get_expert_model_parallel_group(), |
| 506 | + expert_tensor_parallel_group, |
| 507 | + ) |
| 508 | + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) |
| 509 | + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) |
| 510 | + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) |
| 511 | + self.output_quantizer.disable() |
| 512 | + |
| 513 | + # Memorize the original weight.dtype for modelopt_post_restore given that |
| 514 | + # the dtype can change later. |
| 515 | + self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype |
| 516 | + |
| 517 | + @property |
| 518 | + def functionals_to_replace(self): |
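| | + # Patch the autograd function backing GroupedLinear so the quantizers are applied to the
| | + # input activations and to every per-GEMM weight before the original forward runs.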
| 519 | + original_forward = te_grouped_linear._GroupedLinear.forward |
| 520 | + |
| 521 | + def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): |
| 522 | + num_gemms = len(m_splits) |
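| | + # _GroupedLinear.forward receives the num_gemms weights followed by the num_gemms biases
| | + # as its trailing positional args; split them so the weights can be quantized.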
| 523 | + weights_and_biases = args[-2 * num_gemms :] |
| 524 | + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] |
| 525 | + quantized_inputs = self.input_quantizer(inp) |
| 526 | + quantized_weights = [self.weight_quantizer(weight) for weight in weights] |
| 527 | + |
| 528 | + output = original_forward( |
| 529 | + ctx, |
| 530 | + quantized_inputs, |
| 531 | + m_splits, |
| 532 | + *args[: -2 * num_gemms], |
| 533 | + *quantized_weights, |
| 534 | + *biases, |
| 535 | + ) |
| 536 | + return self.output_quantizer(output) |
| 537 | + |
| 538 | + return [ |
| 539 | + ( |
| 540 | + te_grouped_linear._GroupedLinear, |
| 541 | + "forward", |
| 542 | + te_grouped_quantized_linear_fn, |
| 543 | + ), |
| 544 | + ] |
| 545 | + |
| 546 | + def modelopt_post_restore(self, prefix: str = ""): |
| 547 | + """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. |
| 548 | +
|
| 549 | + ModelOpt restores TensorQuantizer states such as `_amax` and `_pre_quant_scale` to the
| 550 | + shapes they had when the checkpoint was saved. For MCore/distributed frameworks this is not
| 551 | + enough: the tensor parallelism can change between saving and restoring, and with it the
| 552 | + shapes of the quantizer states, so they have to be re-calculated here.
| 553 | + """ |
| 554 | + from modelopt.torch.quantization.model_calib import max_calibrate |
| 555 | + |
| 556 | + def _check_unsupported_states(quantizer: TensorQuantizer): |
| 557 | + for k in quantizer.state_dict(): |
| 558 | + if k not in ["_amax", "_pre_quant_scale"]: |
| 559 | + warnings.warn( |
| 560 | + f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " |
| 561 | + f"incorrect. Please implement a custom restore for {k}." |
| 562 | + ) |
| 563 | + |
| 564 | + def _has_state(quantizer, name): |
| 565 | + # Handling for SequentialQuantizer |
| 566 | + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer |
| 567 | + return hasattr(quantizer, name) |
| 568 | + |
| 569 | + # weights for TEGroupedLinear are stored in weight0, weight1, etc. |
| 570 | + if self.weight0 is None: |
| 571 | + return |
| 572 | + for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: |
| 573 | + _check_unsupported_states( |
| 574 | + quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] |
| 575 | + ) |
| 576 | + if _has_state(self.weight_quantizer, "_amax"): |
| 577 | + self.weight_quantizer.reset_amax() |
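| | + # Re-run max calibration over every local weight shard so the restored amax matches the
| | + # current tensor-parallel sharding.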
| 578 | + for i in range(self.num_gemms): |
| 579 | + weight = getattr(self, f"weight{i}") |
| 580 | + assert weight is not None, f"weight{i} is None"
| 581 | + |
| 582 | + max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) |
| 583 | + if _has_state(self.input_quantizer, "_pre_quant_scale"): |
| 584 | + if hasattr(self.input_quantizer, "_pre_quant_scale"): |
| 585 | + delattr(self.input_quantizer, "_pre_quant_scale") |
| | + # Size the buffer from weight0 rather than the loop variable `weight`, which is unset
| | + # when the weight quantizer had no _amax to restore.
| 586 | + pqs = torch.zeros(
| 587 | + self.weight0.shape[1], device=self.weight0.device, dtype=self.original_weight_dtype
| 588 | + )
| 589 | + self.input_quantizer.register_buffer("_pre_quant_scale", pqs) |
| 590 | + |
| 591 | + if _has_state(self.input_quantizer, "_amax"): |
| 592 | + self.input_quantizer.reset_amax() |
| 593 | + dummy_input = torch.ones( |
| 594 | + (1, 1, self.weight0.shape[1]), |
| 595 | + device=self.weight0.device, |
| 596 | + dtype=self.original_weight_dtype, |
| 597 | + ) |
| 598 | + max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) |
| 599 | + if _has_state(self.output_quantizer, "_amax"): |
| 600 | + self.output_quantizer.reset_amax() |
| 601 | + dummy_input = torch.ones( |
| 602 | + (1, 1, self.weight0.shape[0]), |
| 603 | + device=self.weight0.device, |
| 604 | + dtype=self.original_weight_dtype, |
| 605 | + ) |
| 606 | + max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) |
| 607 | + # Let the parent post-restore move any remaining states to the correct device.
| 608 | + # TEGroupedLinear has no single `weight` tensor (the weights live in weight0..weightN),
| | + # so `weight` is set to None before delegating to the parent implementation.
| 609 | + self.weight = None
| 610 | + super().modelopt_post_restore(prefix=prefix)
| 611 | + |
| 612 | + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
| 613 | + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms] to
| 614 | + # the sharded_state_dict; these entries duplicate _extra_state and are only used for TE FP8
| 615 | + # checkpoints, so they are dropped here before restoring a modelopt checkpoint.
| 617 | + filtered_state_dict = { |
| 618 | + k: v |
| 619 | + for k, v in state_dict.items() |
| 620 | + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) |
| 621 | + } |
| 622 | + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) |
| 623 | + |
| 624 | + def _process_quantizer_amax(self, k, v, quantizer_state_dict): |
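| | + # Squeeze the singleton dims out of 4-D grouped amax tensors; otherwise store multi-element
| | + # amax as a per-channel column vector and keep scalar amax 1-D.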
| 625 | + if v.ndim == 4: |
| 626 | + quantizer_state_dict[k] = v.squeeze(1).squeeze(-1) |
| 627 | + else: |
| 628 | + quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1) |
| 629 | + |
| 630 | + |
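| | +# Megatron-Core's TE grouped parallel linears reuse the quantized grouped implementation above;
| | +# the column/row parallel base classes mark the direction of the tensor-parallel split.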
| 631 | +@QuantModuleRegistry.register( |
| 632 | + {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} |
| 633 | +) |
| 634 | +class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): |
| 635 | + _is_column_parallel = True |
| 636 | + |
| 637 | + |
| 638 | +@QuantModuleRegistry.register( |
| 639 | + {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} |
| 640 | +) |
| 641 | +class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
| 642 | + _is_row_parallel = True |