|
79 | 79 | #from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant as fused_rms_fp8_group_quant |
80 | 80 |
|
81 | 81 | if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT: |
| 82 | + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant, fused_reduce_rms_fp8_group_quant |
| 83 | + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale |
82 | 84 | import aiter as rocm_aiter |
83 | 85 | rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 |
84 | | - rocm_aiter_fp8_quant_group_size = 128 |
| 86 | + rocm_aiter_fp8_quant_group_size = 128 |
| 87 | + |
| 88 | + def rocm_aiter_triton_qkv_a_proj_layernorm_impl( |
| 89 | + hidden_states_quant: torch.Tensor, |
| 90 | + hidden_states_quant_scale: torch.Tensor, |
| 91 | + weight_qkv_a_proj: torch.Tensor, |
| 92 | + weight_scale_qkv_a_proj: torch.Tensor, |
| 93 | + q_a_layernorm_weight: torch.Tensor, |
| 94 | + q_a_layernorm_variance_epsilon: float, |
| 95 | + kv_a_layernorm_weight: torch.Tensor, |
| 96 | + kv_a_layernorm_variance_epsilon: float, |
| 97 | + q_lora_rank: int, |
| 98 | + kv_lora_rank: int, |
| 99 | + qk_rope_head_dim: int, |
| 100 | + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 101 | + qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True) |
| 102 | + q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], |
| 103 | + dim=-1, |
| 104 | + ) |
| 105 | + k_pe_reduced = None |
| 106 | + k_pe_reduced_out = None |
| 107 | + if k_pe.dim() == 3: |
| 108 | + M = hidden_states_quant.shape[0] |
| 109 | + device = hidden_states_quant.device |
| 110 | + k_pe_reduced = k_pe |
| 111 | + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] |
| 112 | + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon, |
| 113 | + kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced, |
| 114 | + group_size=rocm_aiter_fp8_quant_group_size, |
| 115 | + dtype_quant=rocm_aiter_fp8_dtype, |
| 116 | + dtype=torch.bfloat16, |
| 117 | + res1=None, |
| 118 | + out3=k_pe_reduced_out) |
| 119 | + if k_pe_reduced_out is not None: |
| 120 | + k_pe = k_pe_reduced_out |
| 121 | + return q_c, q_c_scale, kv_c_normed, k_pe |
| 122 | + |
| 123 | + def rocm_aiter_triton_qkv_a_proj_layernorm_fake( |
| 124 | + hidden_states_quant: torch.Tensor, |
| 125 | + hidden_states_quant_scale: torch.Tensor, |
| 126 | + weight_qkv_a_proj: torch.Tensor, |
| 127 | + weight_scale_qkv_a_proj: torch.Tensor, |
| 128 | + q_a_layernorm_weight: torch.Tensor, |
| 129 | + q_a_layernorm_variance_epsilon: float, |
| 130 | + kv_a_layernorm_weight: torch.Tensor, |
| 131 | + kv_a_layernorm_variance_epsilon: float, |
| 132 | + q_lora_rank: int, |
| 133 | + kv_lora_rank: int, |
| 134 | + qk_rope_head_dim: int, |
| 135 | + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 136 | + M = hidden_states_quant.shape[0] |
| 137 | + device = hidden_states_quant.device |
| 138 | + q_c = torch.empty((M, q_lora_rank), dtype=rocm_aiter_fp8_dtype, device=device) |
| 139 | + q_c_scale = torch.empty((M, (q_lora_rank + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size), dtype=torch.float32, device=device) |
| 140 | + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) |
| 141 | + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] |
| 142 | + return q_c, q_c_scale, kv_c_normed, k_pe |
| 143 | + |
| 144 | + direct_register_custom_op( |
| 145 | + op_name="rocm_aiter_triton_qkv_a_proj_layernorm", |
| 146 | + op_func=rocm_aiter_triton_qkv_a_proj_layernorm_impl, |
| 147 | + mutates_args=[], |
| 148 | + fake_impl=rocm_aiter_triton_qkv_a_proj_layernorm_fake, |
| 149 | + dispatch_key=current_platform.dispatch_key, |
| 150 | + ) |
85 | 151 |
|
86 | 152 | if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD: |
87 | 153 | from aiter.ops.triton.fused_mul_add import fused_mul_add |
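Note on the custom-op registration added above: direct_register_custom_op wires up both the real implementation and a shape-only fake implementation, so the fused op stays traceable under torch.compile / FakeTensor without launching the Triton kernels. Below is a minimal smoke-test sketch, not part of this change: it assumes VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT is enabled so the op is actually registered, the sizes and weight/scale layouts are illustrative placeholders (the fake path only reads the activation shape and the rank arguments), and calling the op on "meta" tensors should dispatch to the fake implementation and return the advertised output shapes.

# Sketch only (not in this PR): shape-level smoke test of the registered op on meta tensors.
import torch

M, H = 16, 7168                                   # illustrative token count / hidden size
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
group = 128                                       # rocm_aiter_fp8_quant_group_size
N = q_lora_rank + kv_lora_rank + qk_rope_head_dim

with torch.device("meta"):
    hs = torch.empty((M, H), dtype=torch.float8_e4m3fnuz)        # placeholder fp8 activations
    hs_scale = torch.empty((M, H // group), dtype=torch.float32)
    w = torch.empty((N, H), dtype=torch.float8_e4m3fnuz)         # placeholder weight layout
    w_scale = torch.empty(((N + group - 1) // group, (H + group - 1) // group),
                          dtype=torch.float32)
    q_ln_w = torch.empty((q_lora_rank,), dtype=torch.bfloat16)
    kv_ln_w = torch.empty((kv_lora_rank,), dtype=torch.bfloat16)

    q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm(
        hs, hs_scale, w, w_scale,
        q_ln_w, 1e-6, kv_ln_w, 1e-6,
        q_lora_rank, kv_lora_rank, qk_rope_head_dim,
    )

assert q_c.shape == (M, q_lora_rank)                              # fp8, consumed by q_b_proj
assert q_c_scale.shape == (M, (q_lora_rank + group - 1) // group) # per-group fp8 scales
assert kv_c_normed.shape == (M, kv_lora_rank)                     # bf16
assert k_pe.shape == (M, qk_rope_head_dim)                        # bf16 strided view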
@@ -653,29 +719,28 @@ def forward( |
653 | 719 | hidden_states, hidden_states_quant = hidden_states |
654 | 720 |
|
655 | 721 | if self.q_lora_rank is not None: |
656 | | - |
657 | | - qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0] |
658 | | - q_c, kv_lora = qkv_lora.split( |
659 | | - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], |
660 | | - dim=-1, |
661 | | - ) |
662 | 722 | if self.use_triton_fused_rmsnorm_fp8_quant: |
663 | | - weight = self.q_a_layernorm.weight |
664 | | - eps = self.q_a_layernorm.variance_epsilon |
665 | | - weight2 = self.kv_a_layernorm.weight |
666 | | - eps2 = self.kv_a_layernorm.variance_epsilon |
667 | | - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], |
668 | | - dim=-1) |
669 | | - (q_c, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(q_c, weight, eps, |
670 | | - kv_c, weight2, eps2, |
671 | | - group_size=rocm_aiter_fp8_quant_group_size, |
672 | | - dtype_quant=rocm_aiter_fp8_dtype, |
673 | | - res1=None) |
| 723 | + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm( |
| 724 | + hidden_states_quant=hidden_states, |
| 725 | + hidden_states_quant_scale=hidden_states_quant, |
| 726 | + weight_qkv_a_proj=self.fused_qkv_a_proj.weight, |
| 727 | + weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale_inv, |
| 728 | + q_a_layernorm_weight=self.q_a_layernorm.weight, |
| 729 | + q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon, |
| 730 | + kv_a_layernorm_weight=self.kv_a_layernorm.weight, |
| 731 | + kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon, |
| 732 | + q_lora_rank=self.q_lora_rank, |
| 733 | + kv_lora_rank=self.kv_lora_rank, |
| 734 | + qk_rope_head_dim=self.qk_rope_head_dim) |
674 | 735 | q = self.q_b_proj(q_c, x_quant_scales = q_c_scale)[0] |
675 | | - else: |
| 736 | + else: |
| 737 | + qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0] |
| 738 | + q_c, kv_lora = qkv_lora.split( |
| 739 | + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], |
| 740 | + dim=-1, |
| 741 | + ) |
676 | 742 | q_c = self.q_a_layernorm(q_c) |
677 | 743 | q = self.q_b_proj(q_c)[0] |
678 | | - |
679 | 744 | else: |
680 | 745 | kv_lora = self.kv_a_proj_with_mqa(hidden_states, x_quant_scales = hidden_states_quant)[0] |
681 | 746 | q = self.q_proj(hidden_states, x_quant_scales = hidden_states_quant)[0] |
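For comparison, here is the unfused path that the new torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm call replaces, reconstructed from the removed lines above. Treat it as a reference sketch rather than the exact semantics of the fused kernel, which additionally runs the block-scale GEMM with skip_reduce=True and folds the leftover k_pe split-K reduction into fused_reduce_rms_fp8_group_quant.

# Reference sketch of the pre-existing unfused path, per the removed code above.
def unfused_qkv_a_proj_layernorm(self, hidden_states, hidden_states_quant):
    # Block-scale GEMM producing the concatenated low-rank projections.
    qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales=hidden_states_quant)[0]
    q_c, kv_lora = qkv_lora.split(
        [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
    kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    # Fused RMSNorm + per-group fp8 quant of q_c, plus RMSNorm of kv_c.
    (q_c, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(
        q_c, self.q_a_layernorm.weight, self.q_a_layernorm.variance_epsilon,
        kv_c, self.kv_a_layernorm.weight, self.kv_a_layernorm.variance_epsilon,
        group_size=rocm_aiter_fp8_quant_group_size,
        dtype_quant=rocm_aiter_fp8_dtype,
        res1=None)
    return q_c, q_c_scale, kv_c_normed, k_pe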
|