|
37 | 37 | ) |
38 | 38 | from modelopt.torch.utils.distributed import ParallelState |
39 | 39 |
|
| 40 | +from ..model_calib import max_calibrate |
40 | 41 | from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer |
41 | 42 | from ..nn.modules.quant_linear import RealQuantLinear |
42 | 43 | from ..qtensor import QTensorWrapper |
|
45 | 46 | try: |
46 | 47 | from megatron.core.extensions.transformer_engine import ( |
47 | 48 | TEColumnParallelGroupedLinear, |
| 49 | + TEDotProductAttention, |
48 | 50 | TERowParallelGroupedLinear, |
49 | 51 | ) |
50 | 52 |
|
@@ -590,6 +592,169 @@ def _setup(self): |
590 | 592 | self.linear_fc1.parallel_state = self.parallel_state |
591 | 593 | self.linear_fc2.parallel_state = self.parallel_state |
592 | 594 |
|
| 595 | + @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"}) |
| 596 | + class _QuantTEDotProductAttention(QuantModule): |
| 597 | + """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization. |
| 598 | +
| 599 | + This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention |
| 600 | + module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer, |
| 601 | + k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after |
| 602 | + RoPE has been applied. |
| 603 | + """ |
| 604 | + |
| 605 | + def _setup(self): |
| 606 | + """Initialize quantizers for Q, K, V tensors.""" |
| 607 | + self.q_bmm_quantizer = TensorQuantizer() |
| 608 | + self.k_bmm_quantizer = TensorQuantizer() |
| 609 | + self.v_bmm_quantizer = TensorQuantizer() |
| 610 | + |
| 611 | + def _calibrate_quantizers(self): |
| 612 | + """Calibrate quantizers with minimal dummy tensors.""" |
 | 613 | + # Use this module's own parameters for device/dtype; fall back to CUDA/FP16 if it has none
| 614 | + param = next(iter(self.parameters()), None) |
| 615 | + device = param.device if param is not None else torch.device("cuda") |
| 616 | + dtype = param.dtype if param is not None else torch.float16 |
| 617 | + |
| 618 | + # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion |
| 619 | + batch_size = 1 |
| 620 | + seq_len = 1 |
| 621 | + |
| 622 | + # Get dimensions from config |
| 623 | + num_heads = self.config.num_attention_heads |
| 624 | + head_dim = ( |
| 625 | + self.config.kv_channels |
 | 626 | + if getattr(self.config, "kv_channels", None) is not None
| 627 | + else self.config.hidden_size // num_heads |
| 628 | + ) |
| 629 | + |
| 630 | + # Determine tensor format (default to sbhd if not specified) |
| 631 | + apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) |
| 632 | + qkv_format = "bshd" if apply_rope_fusion else "sbhd" |
| 633 | + |
| 634 | + if qkv_format == "sbhd": |
| 635 | + dummy_tensor = torch.randn( |
| 636 | + seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype |
| 637 | + ) |
| 638 | + else: |
| 639 | + dummy_tensor = torch.randn( |
| 640 | + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype |
| 641 | + ) |
| 642 | + |
| 643 | + # Calibrate each quantizer |
| 644 | + quantizers = [ |
| 645 | + ("q_bmm_quantizer", self.q_bmm_quantizer), |
| 646 | + ("k_bmm_quantizer", self.k_bmm_quantizer), |
| 647 | + ("v_bmm_quantizer", self.v_bmm_quantizer), |
| 648 | + ] |
| 649 | + |
| 650 | + for _, quantizer in quantizers: |
| 651 | + if quantizer is not None and quantizer.is_enabled(): |
| 652 | + if not hasattr(quantizer, "_amax") or quantizer._amax is None: |
| 653 | + quantizer.reset_amax() |
| 654 | + max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) |
| 655 | + |
| 656 | + def forward(self, query, key, value, *args, **kwargs): |
| 657 | + """Apply post-RoPE quantization to KV cache. |
| 658 | +
| 659 | + TEDotProductAttention receives Q, K, V after RoPE is applied, |
| 660 | + so we quantize them directly for KV cache quantization. |
| 661 | + """ |
| 662 | + # Quantize Q, K, V |
| 663 | + query = self.q_bmm_quantizer(query) |
| 664 | + key = self.k_bmm_quantizer(key) |
| 665 | + value = self.v_bmm_quantizer(value) |
| 666 | + |
| 667 | + return super().forward(query, key, value, *args, **kwargs) |
| 668 | + |
| 669 | + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): |
| 670 | + """Create a sharded state dictionary for distributed checkpointing.""" |
| 671 | + sharded_state_dict = {} |
| 672 | + |
| 673 | + # First add non-quantizer parameters |
| 674 | + for k, v in self.state_dict(prefix="", keep_vars=True).items(): |
 | 675 | + if isinstance(v, torch.Tensor) and "_quantizer" not in k:
| 676 | + sharded_state_dict[prefix + k] = v |
| 677 | + |
| 678 | + # Process _amax in bmm_quantizers |
| 679 | + for name, quantizer in [ |
| 680 | + ("q_bmm_quantizer", self.q_bmm_quantizer), |
| 681 | + ("k_bmm_quantizer", self.k_bmm_quantizer), |
| 682 | + ("v_bmm_quantizer", self.v_bmm_quantizer), |
| 683 | + ]: |
| 684 | + if hasattr(quantizer, "_amax") and quantizer._amax is not None: |
| 685 | + amax_key = f"{prefix}{name}._amax" |
| 686 | + sharded_state_dict[amax_key] = quantizer._amax |
| 687 | + |
| 688 | + # Process other quantizer parameters in bmm_quantizers |
| 689 | + quantizer_state_dict = { |
| 690 | + k: v |
| 691 | + for k, v in self.state_dict(prefix="", keep_vars=True).items() |
| 692 | + if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k |
| 693 | + } |
| 694 | + |
| 695 | + if quantizer_state_dict: |
| 696 | + sharded_state_dict.update( |
| 697 | + **make_sharded_tensors_for_checkpoint( |
| 698 | + quantizer_state_dict, prefix, {}, sharded_offsets |
| 699 | + ) |
| 700 | + ) |
| 701 | + |
| 702 | + return sharded_state_dict |
| 703 | + |
| 704 | + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
| 705 | + """Handle loading state dict for quantizers.""" |
| 706 | + for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: |
| 707 | + full_prefix = f"{prefix}{quantizer_name}." |
| 708 | + amax_key = f"{prefix}{quantizer_name}._amax" |
| 709 | + |
 | 710 | + # Keep the amax tensor under the buffer key TensorQuantizer loads (same key format as saved above)
| 711 | + if amax_key in state_dict: |
| 712 | + expected_amax_key = f"{full_prefix}_amax" |
| 713 | + state_dict[expected_amax_key] = state_dict.pop(amax_key) |
| 714 | + |
| 715 | + # Handle other quantizer states |
| 716 | + for k in list(state_dict.keys()): |
| 717 | + if "_quantizer" in k and "_amax" not in k: |
| 718 | + name = k.split(prefix)[-1] if prefix else k |
| 719 | + if name in self.state_dict(): |
| 720 | + state_dict[k] = state_dict[k].view_as(self.state_dict()[name]) |
| 721 | + |
| 722 | + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
| 723 | + |
| 724 | + def modelopt_post_restore(self, name=""): |
| 725 | + """Restore quantizer states after model loading.""" |
| 726 | + super().modelopt_post_restore(name) |
| 727 | + |
| 728 | + def _check_unsupported_states(quantizer): |
| 729 | + """Check for unsupported quantizer states and warn if found.""" |
| 730 | + if not hasattr(quantizer, "state_dict"): |
| 731 | + return |
| 732 | + |
| 733 | + for k in quantizer.state_dict(): |
| 734 | + if k not in ["_amax", "_pre_quant_scale"]: |
| 735 | + warnings.warn( |
| 736 | + f"Restore of {k} for {name} is not supported. The restore of this layer might be " |
| 737 | + f"incorrect. Please implement a custom restore for {k}." |
| 738 | + ) |
| 739 | + |
| 740 | + calibration_needed = False |
| 741 | + |
| 742 | + for quantizer_name, quantizer in [ |
| 743 | + ("q_bmm_quantizer", self.q_bmm_quantizer), |
| 744 | + ("k_bmm_quantizer", self.k_bmm_quantizer), |
| 745 | + ("v_bmm_quantizer", self.v_bmm_quantizer), |
| 746 | + ]: |
| 747 | + if not hasattr(self, quantizer_name) or not quantizer.is_enabled(): |
| 748 | + continue |
| 749 | + |
| 750 | + _check_unsupported_states(quantizer) |
| 751 | + |
| 752 | + if not hasattr(quantizer, "_amax") or quantizer._amax is None: |
| 753 | + calibration_needed = True |
| 754 | + |
| 755 | + if calibration_needed: |
| 756 | + self._calibrate_quantizers() |
| 757 | + |
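For context (not part of this diff): a minimal sketch of how the post-RoPE Q/K/V quantizers registered above might be enabled on a Megatron-Core model. The wildcard quantizer keys mirror the attribute names created in _setup(), but the exact config entries, the FP8 (E4M3) settings, and the model / calib_dataloader objects are illustrative assumptions; only mtq.quantize, mtq.FP8_DEFAULT_CFG, and the forward_loop(model) calling convention are taken from the modelopt API.

    import copy

    import modelopt.torch.quantization as mtq

    # Start from the stock FP8 recipe and additionally enable the post-RoPE Q/K/V quantizers.
    kv_quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    for name in ("*q_bmm_quantizer", "*k_bmm_quantizer", "*v_bmm_quantizer"):
        kv_quant_cfg["quant_cfg"][name] = {"num_bits": (4, 3), "axis": None}  # FP8 E4M3, per-tensor

    def forward_loop(model):
        # Calibration pass; the batch format depends on the user's Megatron model (assumed here).
        for batch in calib_dataloader:
            model(**batch)

    # mtq.quantize swaps registered modules (including TEDotProductAttention) for their quantized
    # counterparts and runs max calibration through forward_loop.
    model = mtq.quantize(model, kv_quant_cfg, forward_loop)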
593 | 758 |
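The dummy-tensor calibration in _calibrate_quantizers reduces to the following standalone pattern. This is a sketch assuming the default TensorQuantizer settings (enabled, per-tensor, max calibrator); the toy tensor shape follows the "sbhd" layout used above.

    import torch

    from modelopt.torch.quantization.model_calib import max_calibrate
    from modelopt.torch.quantization.nn import TensorQuantizer

    quantizer = TensorQuantizer()   # default config assumed to be enabled with a max calibrator
    toy = torch.randn(2, 1, 8, 64)  # (seq, batch, heads, head_dim) in "sbhd" layout
    # Same call pattern as _calibrate_quantizers: push the toy tensor through the quantizer while
    # it collects statistics, skipping any distributed amax synchronization.
    max_calibrate(quantizer, lambda q: q(toy), distributed_sync=False)
    print(quantizer._amax)  # scalar amax recorded from the toy tensor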
|
594 | 759 | @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"}) |
595 | 760 | class _QuantMoELayer(QuantModule): |
|
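For the checkpointing methods above, the quantizer-related keys involved look roughly as follows; the module path is an assumed example for one Megatron-Core attention block, not something defined in this diff.

    # Assumed (illustrative) prefix of one core attention module in a Megatron-Core GPT model:
    attn_prefix = "decoder.layers.0.self_attention.core_attention."

    # sharded_state_dict() emits one raw amax tensor per enabled bmm quantizer:
    amax_keys = [
        attn_prefix + name + "._amax"
        for name in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer")
    ]

    # _load_from_state_dict() consumes the same keys at load time, and modelopt_post_restore()
    # falls back to _calibrate_quantizers() (dummy-tensor max calibration) for any enabled
    # quantizer whose amax is still missing after the restore.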