
Commit ae9c305

add a16w8 gemm for DS-R1 o_proj decode; add rocm_aiter_triton_qkv_a_proj_layernorm
1 parent 06a0165 commit ae9c305

File tree

2 files changed: +96 -20 lines

    vllm/model_executor/layers/quantization/utils/fp8_utils.py
    vllm/model_executor/models/deepseek_v2.py

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 11 additions & 0 deletions
@@ -93,6 +93,11 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
 from aiter import get_hip_quant
 
 aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+
+def aiter_triton_a16w8_gemm_check(m, n, k):
+    if m <= 256:
+        return (n == 7168 and k == 2048)  # DS-R1 o_proj for decode
+    return False
 
 
 def rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl(
@@ -236,6 +241,12 @@ def apply_w8a8_block_fp8_linear(
         q_input = input
         x_scale = input_quant_scale
         output_dtype = torch.bfloat16
+    elif aiter_triton_a16w8_gemm_check(input_2d.shape[0], weight.shape[0], input_2d.shape[1]):
+        from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale
+        output = gemm_a16w8_blockscale(input_2d, weight, weight_scale, dtype=output_dtype)
+        if bias is not None:
+            output = output + bias
+        return output.view(*output_shape)
     elif use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
         q_input, x_scale = aiter_per1x128_quant(
             input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
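
The dispatch above only takes the new path for small-M (decode-sized) GEMMs whose N/K match the DeepSeek-R1 o_proj shape; everything else still goes through the per-1x128 activation quant plus a8w8 blockscale route. A minimal sketch of that gate follows; the helper mirrors the check added in this diff, and the example shapes are purely illustrative, not measured from a running model:

    # Sketch: which GEMMs take the new a16w8 decode path (illustrative shapes only).
    def aiter_triton_a16w8_gemm_check(m, n, k):
        if m <= 256:
            return (n == 7168 and k == 2048)  # DS-R1 o_proj for decode
        return False

    for m, n, k in [(1, 7168, 2048),    # single-token decode  -> a16w8 blockscale
                    (64, 7168, 2048),   # small decode batch   -> a16w8 blockscale
                    (512, 7168, 2048),  # prefill-sized M      -> quant + a8w8 path
                    (64, 4096, 7168)]:  # some other layer     -> quant + a8w8 path
        path = ("gemm_a16w8_blockscale (bf16 activations, fp8 block-scaled weight)"
                if aiter_triton_a16w8_gemm_check(m, n, k)
                else "aiter_per1x128_quant + gemm_a8w8_blockscale")
        print(f"M={m:4d} N={n:5d} K={k:5d} -> {path}")

The M <= 256 bound keeps this on the decode side, where the activation matrix is small enough that skipping the separate activation-quant step is presumably the win; prefill-sized M stays on the existing quantized a8w8 route.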

vllm/model_executor/models/deepseek_v2.py

Lines changed: 85 additions & 20 deletions
@@ -79,9 +79,75 @@
 #from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant as fused_rms_fp8_group_quant
 
 if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant, fused_reduce_rms_fp8_group_quant
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
-    rocm_aiter_fp8_quant_group_size = 128
+    rocm_aiter_fp8_quant_group_size = 128
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_impl(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True)
+        q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim],
+                                         dim=-1,
+                                         )
+        k_pe_reduced = None
+        k_pe_reduced_out = None
+        if k_pe.dim() == 3:
+            M = hidden_states_quant.shape[0]
+            device = hidden_states_quant.device
+            k_pe_reduced = k_pe
+            k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon,
+                kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced,
+                group_size=rocm_aiter_fp8_quant_group_size,
+                dtype_quant=rocm_aiter_fp8_dtype,
+                dtype=torch.bfloat16,
+                res1=None,
+                out3=k_pe_reduced_out)
+        if k_pe_reduced_out is not None:
+            k_pe = k_pe_reduced_out
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    def rocm_aiter_triton_qkv_a_proj_layernorm_fake(
+        hidden_states_quant: torch.Tensor,
+        hidden_states_quant_scale: torch.Tensor,
+        weight_qkv_a_proj: torch.Tensor,
+        weight_scale_qkv_a_proj: torch.Tensor,
+        q_a_layernorm_weight: torch.Tensor,
+        q_a_layernorm_variance_epsilon: float,
+        kv_a_layernorm_weight: torch.Tensor,
+        kv_a_layernorm_variance_epsilon: float,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        M = hidden_states_quant.shape[0]
+        device = hidden_states_quant.device
+        q_c = torch.empty((M, q_lora_rank), dtype=rocm_aiter_fp8_dtype, device=device)
+        q_c_scale = torch.empty((M, (q_lora_rank + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size), dtype=torch.float32, device=device)
+        kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device)
+        k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+        return q_c, q_c_scale, kv_c_normed, k_pe
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_qkv_a_proj_layernorm",
+        op_func=rocm_aiter_triton_qkv_a_proj_layernorm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_triton_qkv_a_proj_layernorm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
 
 if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD:
     from aiter.ops.triton.fused_mul_add import fused_mul_add
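
A note on the impl/fake pair registered above: direct_register_custom_op exposes the Triton path as torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm, and the fake only has to return tensors with the right shapes, dtypes and device so the op can be traced (for example under torch.compile or meta-tensor shape propagation) without launching any kernels. A quick shape sanity check of what the fake allocates, using assumed DeepSeek-R1-style dimensions (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64) and torch.float8_e4m3fn as a stand-in for rocm_aiter.dtypes.fp8:

    import torch

    # Assumed DS-R1-style dims; the real values come from the model config at runtime.
    M, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 8, 1536, 512, 64
    group_size = 128                      # rocm_aiter_fp8_quant_group_size
    fp8 = torch.float8_e4m3fn             # stand-in for rocm_aiter.dtypes.fp8

    q_c = torch.empty((M, q_lora_rank), dtype=fp8)
    q_c_scale = torch.empty((M, (q_lora_rank + group_size - 1) // group_size),
                            dtype=torch.float32)
    kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16)
    # k_pe is a strided slice of a buffer sized like the full fused qkv output,
    # mirroring the k_pe view produced by splitting that output in the real impl.
    k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
                       dtype=torch.bfloat16)[..., :qk_rope_head_dim]

    print(q_c.shape, q_c_scale.shape, kv_c_normed.shape, k_pe.shape)
    # torch.Size([8, 1536]) torch.Size([8, 12]) torch.Size([8, 512]) torch.Size([8, 64])
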
@@ -653,29 +719,28 @@ def forward(
             hidden_states, hidden_states_quant = hidden_states
 
         if self.q_lora_rank is not None:
-
-            qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]
-            q_c, kv_lora = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
             if self.use_triton_fused_rmsnorm_fp8_quant:
-                weight = self.q_a_layernorm.weight
-                eps = self.q_a_layernorm.variance_epsilon
-                weight2 = self.kv_a_layernorm.weight
-                eps2 = self.kv_a_layernorm.variance_epsilon
-                kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
-                                           dim=-1)
-                (q_c, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(q_c, weight, eps,
-                                                                                kv_c, weight2, eps2,
-                                                                                group_size=rocm_aiter_fp8_quant_group_size,
-                                                                                dtype_quant=rocm_aiter_fp8_dtype,
-                                                                                res1=None)
+                q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm(
+                    hidden_states_quant=hidden_states,
+                    hidden_states_quant_scale=hidden_states_quant,
+                    weight_qkv_a_proj=self.fused_qkv_a_proj.weight,
+                    weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale_inv,
+                    q_a_layernorm_weight=self.q_a_layernorm.weight,
+                    q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon,
+                    kv_a_layernorm_weight=self.kv_a_layernorm.weight,
+                    kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    qk_rope_head_dim=self.qk_rope_head_dim)
                 q = self.q_b_proj(q_c, x_quant_scales = q_c_scale)[0]
-            else:
+            else:
+                qkv_lora = self.fused_qkv_a_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]
+                q_c, kv_lora = qkv_lora.split(
+                    [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                    dim=-1,
+                )
                 q_c = self.q_a_layernorm(q_c)
                 q = self.q_b_proj(q_c)[0]
-
         else:
             kv_lora = self.kv_a_proj_with_mqa(hidden_states, x_quant_scales = hidden_states_quant)[0]
             q = self.q_proj(hidden_states, x_quant_scales = hidden_states_quant)[0]