11 changes: 11 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -93,6 +93,11 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
     from aiter import get_hip_quant

     aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+
+    def aiter_triton_a16w8_gemm_check(m, n, k):
+        if m <= 256:
+            return (n == 7168 and k == 2048)  # DS-R1 o_proj for decode
+        return False


 def rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl(
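For context, the gate added above only admits small decode batches hitting the DeepSeek-R1 o_proj GEMM shape; everything else falls through to the existing paths. A minimal, self-contained restatement of the predicate (the example calls are illustrative):

```python
def aiter_triton_a16w8_gemm_check(m: int, n: int, k: int) -> bool:
    """Route to the Triton a16w8 kernel only for decode-sized batches
    (m <= 256) on the DeepSeek-R1 o_proj GEMM (n=7168, k=2048)."""
    if m <= 256:
        return n == 7168 and k == 2048  # DS-R1 o_proj for decode
    return False

assert aiter_triton_a16w8_gemm_check(64, 7168, 2048)        # decode o_proj: Triton a16w8
assert not aiter_triton_a16w8_gemm_check(1024, 7168, 2048)  # prefill-sized m: default path
assert not aiter_triton_a16w8_gemm_check(64, 4096, 4096)    # other GEMM shapes: default path
```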
@@ -236,6 +241,12 @@ def apply_w8a8_block_fp8_linear(
         q_input = input
         x_scale = input_quant_scale
         output_dtype = torch.bfloat16
+    elif aiter_triton_a16w8_gemm_check(input_2d.shape[0], weight.shape[0], input_2d.shape[1]):
+        from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale
+        output = gemm_a16w8_blockscale(input_2d, weight, weight_scale, dtype=output_dtype)
+        if bias is not None:
+            output = output + bias
+        return output.view(*output_shape)
     elif use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
         q_input, x_scale = aiter_per1x128_quant(
             input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
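The new branch returns early and skips activation quantization entirely: activations stay in bf16 (a16) and only the weight is fp8 with 128x128 block scales. A rough reference for what `gemm_a16w8_blockscale` computes, assuming `[N, K]` fp8 weights and a `[ceil(N/128), ceil(K/128)]` scale grid (a sketch of the math, not the Triton kernel):

```python
import torch

def gemm_a16w8_blockscale_ref(x: torch.Tensor,    # [M, K] bf16 activations, unquantized
                              w_q: torch.Tensor,  # [N, K] fp8 weight
                              w_s: torch.Tensor,  # [ceil(N/128), ceil(K/128)] fp32 block scales
                              block: int = 128,
                              dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Dequantize: broadcast each block scale over its 128x128 weight tile.
    n, k = w_q.shape
    s = w_s.repeat_interleave(block, dim=0)[:n]
    s = s.repeat_interleave(block, dim=1)[:, :k]
    w = w_q.to(torch.float32) * s
    # Plain GEMM against the dequantized weight.
    return (x.to(torch.float32) @ w.t()).to(dtype)
```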
105 changes: 85 additions & 20 deletions vllm/model_executor/models/deepseek_v2.py
@@ -79,9 +79,75 @@
 #from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant as fused_rms_fp8_group_quant

 if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant, fused_reduce_rms_fp8_group_quant
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
     rocm_aiter_fp8_quant_group_size = 128
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_impl(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj,
+                                        hidden_states_quant_scale, weight_scale_qkv_a_proj,
+                                        skip_reduce=True)
+        q_c, kv_c, k_pe = qkv_lora.split(
+            [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)
+        k_pe_reduced = None
+        k_pe_reduced_out = None
+        if k_pe.dim() == 3:
+            # skip_reduce=True can leave split-K partials unreduced; hand the
+            # fused kernel a strided bf16 view to reduce k_pe into.
+            M = hidden_states_quant.shape[0]
+            device = hidden_states_quant.device
+            k_pe_reduced = k_pe
+            k_pe_reduced_out = torch.empty(
+                (M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
+                dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(
+            q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon,
+            kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon,
+            k_pe_reduced,
+            group_size=rocm_aiter_fp8_quant_group_size,
+            dtype_quant=rocm_aiter_fp8_dtype,
+            dtype=torch.bfloat16,
+            res1=None,
+            out3=k_pe_reduced_out)
+        if k_pe_reduced_out is not None:
+            k_pe = k_pe_reduced_out
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_fake(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        M = hidden_states_quant.shape[0]
+        device = hidden_states_quant.device
+        q_c = torch.empty((M, q_lora_rank), dtype=rocm_aiter_fp8_dtype, device=device)
+        q_c_scale = torch.empty(
+            (M, (q_lora_rank + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size),
+            dtype=torch.float32, device=device)
+        kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device)
+        # Match the impl's strided k_pe view so compiled graphs see identical metadata.
+        k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
+                           dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_qkv_a_proj_layernorm",
+        op_func=rocm_aiter_triton_qkv_a_proj_layernorm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_triton_qkv_a_proj_layernorm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )

 if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD:
     from aiter.ops.triton.fused_mul_add import fused_mul_add
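Registering the fused projection as a custom op keeps it traceable under torch.compile: the graph is built against the fake impl, so the fake's output metadata (dtypes, shapes, and the strided `k_pe` view) must match the real kernel exactly. A toy registration following the same pattern (the toy op names are illustrative; the import locations for vLLM's `direct_register_custom_op` and `current_platform` are assumptions):

```python
import torch
from vllm.platforms import current_platform       # location assumed
from vllm.utils import direct_register_custom_op  # location assumed

def _toy_scale_impl(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # real kernel: does the actual work

def _toy_scale_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake impl: produces matching shape/dtype metadata only, never computes.
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="toy_scale",
    op_func=_toy_scale_impl,
    mutates_args=[],
    fake_impl=_toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,
)

# Dispatch then goes through the registered op, e.g.:
# y = torch.ops.vllm.toy_scale(torch.randn(4, device="cuda"))
```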
@@ -653,29 +719,28 @@ def forward(
         hidden_states, hidden_states_quant = hidden_states

         if self.q_lora_rank is not None:
-
-            qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales=hidden_states_quant)[0]
-            q_c, kv_lora = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
             if self.use_triton_fused_rmsnorm_fp8_quant:
-                weight = self.q_a_layernorm.weight
-                eps = self.q_a_layernorm.variance_epsilon
-                weight2 = self.kv_a_layernorm.weight
-                eps2 = self.kv_a_layernorm.variance_epsilon
-                kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
-                                           dim=-1)
-                (q_c, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(
-                    q_c, weight, eps, kv_c, weight2, eps2,
-                    group_size=rocm_aiter_fp8_quant_group_size,
-                    dtype_quant=rocm_aiter_fp8_dtype,
-                    res1=None)
+                q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm(
+                    hidden_states_quant=hidden_states,
+                    hidden_states_quant_scale=hidden_states_quant,
+                    weight_qkv_a_proj=self.fused_qkv_a_proj.weight,
+                    weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale_inv,
+                    q_a_layernorm_weight=self.q_a_layernorm.weight,
+                    q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon,
+                    kv_a_layernorm_weight=self.kv_a_layernorm.weight,
+                    kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    qk_rope_head_dim=self.qk_rope_head_dim)
                 q = self.q_b_proj(q_c, x_quant_scales=q_c_scale)[0]
             else:
+                qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales=hidden_states_quant)[0]
+                q_c, kv_lora = qkv_lora.split(
+                    [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                    dim=-1,
+                )
                 q_c = self.q_a_layernorm(q_c)
                 q = self.q_b_proj(q_c)[0]

         else:
             kv_lora = self.kv_a_proj_with_mqa(hidden_states, x_quant_scales=hidden_states_quant)[0]
             q = self.q_proj(hidden_states, x_quant_scales=hidden_states_quant)[0]
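For orientation, the fused op returns the same three logical pieces the eager path obtains by splitting the qkv_a projection output column-wise. With DeepSeek-V2/R1-style sizes (assumed here purely for illustration: q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64), the split looks like this:

```python
import torch

M = 8  # e.g. a decode batch of 8 tokens
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64

# One row of the qkv_a output holds [q_c | kv_c | k_pe] side by side.
qkv_lora = torch.randn(M, q_lora_rank + kv_lora_rank + qk_rope_head_dim)
q_c, kv_c, k_pe = qkv_lora.split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)

assert q_c.shape == (M, 1536)   # -> q_a_layernorm -> fp8 group quant -> q_b_proj
assert kv_c.shape == (M, 512)   # -> kv_a_layernorm (bf16) -> KV cache path
assert k_pe.shape == (M, 64)    # rotary part, passed through unnormalized
```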