diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 43a352d7027d..1c7a493297f0 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -93,6 +93,11 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
     from aiter import get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+
+    def aiter_triton_a16w8_gemm_check(m, n, k):
+        if m <= 256:
+            return (n == 7168 and k == 2048)  # DS-R1 o_proj for decode
+        return False
 
 
 def rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl(
@@ -236,6 +241,12 @@ def apply_w8a8_block_fp8_linear(
             q_input = input
             x_scale = input_quant_scale
             output_dtype = torch.bfloat16
+        elif aiter_triton_a16w8_gemm_check(input_2d.shape[0], weight.shape[0], input_2d.shape[1]):
+            from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale
+            output = gemm_a16w8_blockscale(input_2d, weight, weight_scale, dtype=output_dtype)
+            if bias is not None:
+                output = output + bias
+            return output.view(*output_shape)
         elif use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
             q_input, x_scale = aiter_per1x128_quant(
                 input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index e40839b1d475..6d9da6bcab65 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -79,9 +79,75 @@
 #from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant as fused_rms_fp8_group_quant
 
 if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant, fused_reduce_rms_fp8_group_quant
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
-    rocm_aiter_fp8_quant_group_size = 128
+    rocm_aiter_fp8_quant_group_size = 128
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_impl(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True)
+        q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim],
+                                         dim=-1,
+                                         )
+        k_pe_reduced = None
+        k_pe_reduced_out = None
+        if k_pe.dim() == 3:
+            M = hidden_states_quant.shape[0]
+            device = hidden_states_quant.device
+            k_pe_reduced = k_pe
+            k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon,
+                                                                                                 kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced,
+                                                                                                 group_size=rocm_aiter_fp8_quant_group_size,
+                                                                                                 dtype_quant=rocm_aiter_fp8_dtype,
+                                                                                                 dtype=torch.bfloat16,
+                                                                                                 res1=None,
+                                                                                                 out3=k_pe_reduced_out)
+        if k_pe_reduced_out is not None:
+            k_pe = k_pe_reduced_out
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_fake(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        M = hidden_states_quant.shape[0]
+        device = hidden_states_quant.device
+        q_c = torch.empty((M, q_lora_rank), dtype=rocm_aiter_fp8_dtype, device=device)
+        q_c_scale = torch.empty((M, (q_lora_rank + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size), dtype=torch.float32, device=device)
+        kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device)
+        k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_qkv_a_proj_layernorm",
+        op_func=rocm_aiter_triton_qkv_a_proj_layernorm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_triton_qkv_a_proj_layernorm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
 
 if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD:
     from aiter.ops.triton.fused_mul_add import fused_mul_add
@@ -653,29 +719,28 @@ def forward(
             hidden_states, hidden_states_quant = hidden_states
 
         if self.q_lora_rank is not None:
-
-            qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]
-            q_c, kv_lora = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
             if self.use_triton_fused_rmsnorm_fp8_quant:
-                weight = self.q_a_layernorm.weight
-                eps = self.q_a_layernorm.variance_epsilon
-                weight2 = self.kv_a_layernorm.weight
-                eps2 = self.kv_a_layernorm.variance_epsilon
-                kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
-                                           dim=-1)
-                (q_c, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(q_c, weight, eps,
-                                                                                kv_c, weight2, eps2,
-                                                                                group_size=rocm_aiter_fp8_quant_group_size,
-                                                                                dtype_quant=rocm_aiter_fp8_dtype,
-                                                                                res1=None)
+                q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm(
+                    hidden_states_quant=hidden_states,
+                    hidden_states_quant_scale=hidden_states_quant,
+                    weight_qkv_a_proj=self.fused_qkv_a_proj.weight,
+                    weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale_inv,
+                    q_a_layernorm_weight=self.q_a_layernorm.weight,
+                    q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon,
+                    kv_a_layernorm_weight=self.kv_a_layernorm.weight,
+                    kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    qk_rope_head_dim=self.qk_rope_head_dim)
                 q = self.q_b_proj(q_c, x_quant_scales = q_c_scale)[0]
-            else:
+            else:
+                qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]
+                q_c, kv_lora = qkv_lora.split(
+                    [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                    dim=-1,
+                )
                 q_c = self.q_a_layernorm(q_c)
                 q = self.q_b_proj(q_c)[0]
-
         else:
             kv_lora = self.kv_a_proj_with_mqa(hidden_states, x_quant_scales = hidden_states_quant)[0]
             q = self.q_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]
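Not part of the patch: a minimal CPU-only sketch of the shape gating added in fp8_utils.py and of the output contract that rocm_aiter_triton_qkv_a_proj_layernorm_fake advertises to torch.compile. The DeepSeek-R1 MLA sizes (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64) and the batch size M are illustrative assumptions, and the default float dtype stands in for the aiter fp8 dtype used by the real op.

```python
import torch

# Same predicate as the patch: restrict the a16w8 Triton GEMM fast path to the
# DeepSeek-R1 o_proj shape (N=7168, K=2048) for small decode batches (M <= 256).
def aiter_triton_a16w8_gemm_check(m: int, n: int, k: int) -> bool:
    if m <= 256:
        return n == 7168 and k == 2048
    return False

# Shape contract mirrored from the fake impl; dimensions below are assumed
# DeepSeek-R1 MLA sizes, used purely for illustration.
M = 32                 # decode batch size (assumed)
q_lora_rank = 1536     # assumed
kv_lora_rank = 512     # assumed
qk_rope_head_dim = 64  # assumed
group_size = 128       # rocm_aiter_fp8_quant_group_size

q_c = torch.empty((M, q_lora_rank))  # fp8 in the real op
q_c_scale = torch.empty((M, (q_lora_rank + group_size - 1) // group_size))  # one scale per 128-wide group
kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16)
# Strided view over the full qkv_lora row width, as in the fake impl.
k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
                   dtype=torch.bfloat16)[..., :qk_rope_head_dim]

assert aiter_triton_a16w8_gemm_check(M, 7168, 2048)        # decode o_proj hits the fast path
assert not aiter_triton_a16w8_gemm_check(512, 7168, 2048)  # large prefill batches fall through
assert q_c_scale.shape == (M, 12)                          # ceil(1536 / 128)
assert k_pe.shape == (M, qk_rope_head_dim)
```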