
Commit 23ff415

yzhou103 and valarLip authored
Fused qk rope cat and cache mla (ROCm#1380)
* first version
* opt to vec and add perf compared with triton
* opt
* add fused_qk_rope_concat_and_cache opt kernel and opt concat_and_cache
* use buffer_o.template set_raw() with bf16 output, as set() will result in 2 buffer_store_dwordx2
* fix lint error
* refactor interface and fix error when is_neox=false
* fix rot_dim!=64 and is_nope_first=false
* fix error when input is not contiguous

Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
1 parent: 8910746; commit: 23ff415

File tree: 6 files changed, +1649 -95 lines


aiter/ops/cache.py

Lines changed: 19 additions & 0 deletions
@@ -115,3 +115,22 @@ def cp_gather_indexer_k_quant_cache(
     block_table: Tensor,
     cu_seq_lens: Tensor,
 ) -> None: ...
+
+
+@compile_ops("module_cache")
+def fused_qk_rope_concat_and_cache_mla(
+    q_nope: Tensor,
+    q_pe: Tensor,
+    kv_c: Tensor,
+    k_pe: Tensor,  # key tensor
+    kv_cache: Tensor,
+    q_out: Tensor,
+    slot_mapping: Tensor,
+    k_scale: Tensor,
+    q_scale: Tensor,
+    positions: Tensor,
+    cos_cache: Tensor,
+    sin_cache: Tensor,
+    is_neox: bool,
+    is_nope_first: bool,
+) -> None: ...
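
For orientation, a minimal call sketch of the new op. The shapes follow the comments in csrc/include/cache.h below; the toy sizes, bf16 dtype, placeholder scales, and the choice qk_lora_rank == kv_lora_rank are illustrative assumptions, not part of this commit.

import torch
from aiter.ops.cache import fused_qk_rope_concat_and_cache_mla

num_tokens, num_heads = 4, 16
kv_lora_rank, pe_dim = 512, 64            # assume qk_lora_rank == kv_lora_rank
num_blocks, block_size, max_pos = 8, 16, 4096
dev, dt = "cuda", torch.bfloat16

q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank, device=dev, dtype=dt)
q_pe   = torch.randn(num_tokens, num_heads, pe_dim, device=dev, dtype=dt)
kv_c   = torch.randn(num_tokens, kv_lora_rank, device=dev, dtype=dt)
k_pe   = torch.randn(num_tokens, pe_dim, device=dev, dtype=dt)

# Outputs: q_out receives the concatenated/rotated query; kv_cache is the paged MLA cache.
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim, device=dev, dtype=dt)
q_out    = torch.empty(num_tokens, num_heads, kv_lora_rank + pe_dim, device=dev, dtype=dt)

slot_mapping = torch.arange(num_tokens, device=dev, dtype=torch.int64)  # flat slot per token
positions    = torch.arange(num_tokens, device=dev, dtype=torch.int64)
cos_cache    = torch.randn(max_pos, pe_dim // 2, device=dev, dtype=dt)
sin_cache    = torch.randn(max_pos, pe_dim // 2, device=dev, dtype=dt)
k_scale = torch.ones(1, device=dev, dtype=torch.float32)  # placeholder scales; the fp8
q_scale = torch.ones(1, device=dev, dtype=torch.float32)  # cache paths would use real ones

fused_qk_rope_concat_and_cache_mla(
    q_nope, q_pe, kv_c, k_pe, kv_cache, q_out, slot_mapping,
    k_scale, q_scale, positions, cos_cache, sin_cache,
    is_neox=True, is_nope_first=True,
)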

csrc/include/cache.h

Lines changed: 18 additions & 0 deletions
@@ -85,4 +85,22 @@ void cp_gather_indexer_k_quant_cache(
     torch::Tensor& dst_scale,         // [num_tokens, head_dim / quant_block_size * 4]
     const torch::Tensor& block_table, // [batch_size, num_blocks]
     const torch::Tensor& cu_seq_lens); // [batch_size + 1]
+
+void fused_qk_rope_concat_and_cache_mla(
+    torch::Tensor& q_nope,       // [num_tokens, num_heads, qk_lora_rank]
+    torch::Tensor& q_pe,         // [num_tokens, num_heads, pe_dim]
+    torch::Tensor& kv_c,         // [num_tokens, kv_lora_rank]
+    torch::Tensor& k_pe,         // [num_tokens, pe_dim]
+    torch::Tensor& kv_cache,     // [num_blocks, block_size, (kv_lora_rank +
+                                 // pe_dim)]
+    torch::Tensor& q_out,        // [num_tokens, num_heads, qk_lora_rank+pe_dim]
+    torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
+    torch::Tensor& k_scale,
+    torch::Tensor& q_scale,
+    torch::Tensor& positions,    // [num_tokens]
+    torch::Tensor& cos_cache,    // [max_positions, pe_dim//2]
+    torch::Tensor& sin_cache,    // [max_positions, pe_dim//2]
+    bool is_neox,
+    bool is_nope_first);
+
 } // namespace aiter
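
To make the intended semantics concrete, here is a rough, unfused Python reference of what a kernel with this signature is expected to compute. This is my reading of the argument names and shape comments, not code from the commit: the scales are ignored (the "auto" cache-dtype path), and the is_nope_first concatenation order and the two RoPE variants are assumptions.

import torch

def reference_qk_rope_concat_and_cache_mla(
    q_nope, q_pe, kv_c, k_pe, kv_cache, q_out,
    slot_mapping, positions, cos_cache, sin_cache,
    is_neox: bool, is_nope_first: bool,
):
    cos = cos_cache[positions]  # [num_tokens, pe_dim // 2]
    sin = sin_cache[positions]

    def rope(x):  # x: [..., pe_dim]; broadcast cos/sin over the head dim if present
        c = cos if x.dim() == cos.dim() else cos.unsqueeze(1)
        s = sin if x.dim() == sin.dim() else sin.unsqueeze(1)
        if is_neox:  # GPT-NeoX style: rotate the two contiguous halves
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
        out = torch.empty_like(x)  # GPT-J style: rotate interleaved even/odd pairs
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out[..., 0::2] = x1 * c - x2 * s
        out[..., 1::2] = x2 * c + x1 * s
        return out

    q_rot, k_rot = rope(q_pe), rope(k_pe)

    # Concatenate the nope and rope parts of q in the order picked by is_nope_first.
    parts = (q_nope, q_rot) if is_nope_first else (q_rot, q_nope)
    q_out.copy_(torch.cat(parts, dim=-1))

    # Scatter compressed KV plus rotated k_pe into the paged cache by flat slot id.
    flat = kv_cache.view(-1, kv_cache.shape[-1])  # [num_blocks * block_size, ...]
    flat[slot_mapping] = torch.cat((kv_c, k_rot), dim=-1)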

csrc/include/quant_utils.cuh

Lines changed: 64 additions & 1 deletion
@@ -1,6 +1,6 @@
 #pragma once
 /*
- * Copyright © Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (C) 2024-2025, The vLLM team.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -783,6 +783,69 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale)
     } \
 }

+#define DISPATCH_BY_KV_CACHE_QUERY_DTYPE(SRC_DTYPE, KV_DTYPE, QUERY_DTYPE, FN) \
+    if(KV_DTYPE == "auto" && QUERY_DTYPE == "auto") \
+    { \
+        if(SRC_DTYPE == at::ScalarType::Float) \
+        { \
+            FN(float, float, float, vllm::Fp8KVCacheDataType::kAuto, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::Half) \
+        { \
+            FN(ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, vllm::Fp8KVCacheDataType::kAuto, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::BFloat16) \
+        { \
+            FN(ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, vllm::Fp8KVCacheDataType::kAuto, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else \
+        { \
+            TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
+        } \
+    } \
+    else if((KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") && (QUERY_DTYPE == "auto")) \
+    { \
+        if(SRC_DTYPE == at::ScalarType::Float) \
+        { \
+            FN(float, ck_tile::fp8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::Half) \
+        { \
+            FN(ck_tile::fp16_t, ck_tile::fp8_t, ck_tile::fp16_t, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::BFloat16) \
+        { \
+            FN(ck_tile::bf16_t, ck_tile::fp8_t, ck_tile::bf16_t, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kAuto); \
+        } \
+        else \
+        { \
+            TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
+        } \
+    } \
+    else if((KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") && (QUERY_DTYPE == "fp8" || QUERY_DTYPE == "fp8_e4m3")) \
+    { \
+        if(SRC_DTYPE == at::ScalarType::Float) \
+        { \
+            FN(float, ck_tile::fp8_t, ck_tile::fp8_t, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::Half) \
+        { \
+            FN(ck_tile::fp16_t, ck_tile::fp8_t, ck_tile::fp8_t, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+        } \
+        else if(SRC_DTYPE == at::ScalarType::BFloat16) \
+        { \
+            FN(ck_tile::bf16_t, ck_tile::fp8_t, ck_tile::fp8_t, vllm::Fp8KVCacheDataType::kFp8E4M3, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+        } \
+        else \
+        { \
+            TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
+        } \
+    } \
+    else \
+    { \
+        TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE, " Query type: ", QUERY_DTYPE); \
+    }
 } // namespace fp8
 #endif // USE_ROCM
 } // namespace vllm
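
Read as a truth table, the new macro picks the (source, cache, query) element types plus two Fp8KVCacheDataType enums and hands them to FN. A purely illustrative Python rendering of the same mapping, with ck_tile/vllm names abbreviated (not code from the commit):

# (kv_cache_dtype, query_dtype, src_torch_dtype) -> (SRC_T, CACHE_T, QUERY_T, KV_ENUM, Q_ENUM)
DISPATCH = {
    ("auto", "auto", "float32"):  ("float",  "float",  "float",  "kAuto",    "kAuto"),
    ("auto", "auto", "float16"):  ("fp16_t", "fp16_t", "fp16_t", "kAuto",    "kAuto"),
    ("auto", "auto", "bfloat16"): ("bf16_t", "bf16_t", "bf16_t", "kAuto",    "kAuto"),
    ("fp8",  "auto", "float32"):  ("float",  "fp8_t",  "float",  "kFp8E4M3", "kAuto"),
    ("fp8",  "auto", "float16"):  ("fp16_t", "fp8_t",  "fp16_t", "kFp8E4M3", "kAuto"),
    ("fp8",  "auto", "bfloat16"): ("bf16_t", "fp8_t",  "bf16_t", "kFp8E4M3", "kAuto"),
    ("fp8",  "fp8",  "float32"):  ("float",  "fp8_t",  "fp8_t",  "kFp8E4M3", "kFp8E4M3"),
    ("fp8",  "fp8",  "float16"):  ("fp16_t", "fp8_t",  "fp8_t",  "kFp8E4M3", "kFp8E4M3"),
    ("fp8",  "fp8",  "bfloat16"): ("bf16_t", "fp8_t",  "fp8_t",  "kFp8E4M3", "kFp8E4M3"),
}
# "fp8_e4m3" is accepted as an alias for "fp8"; any other combination fails via TORCH_CHECK.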

csrc/include/rocm_ops.hpp

Lines changed: 30 additions & 1 deletion
@@ -311,7 +311,36 @@ namespace py = pybind11;
           py::arg("dst_k"), \
           py::arg("dst_scale"), \
           py::arg("block_table"), \
-          py::arg("cu_seq_lens"));
+          py::arg("cu_seq_lens")); \
+    m.def("fused_qk_rope_concat_and_cache_mla", \
+          &aiter::fused_qk_rope_concat_and_cache_mla, \
+          "fused_qk_rope_concat_and_cache_mla(" \
+          "    Tensor q_nope, Tensor q_pe," \
+          "    Tensor kv_c, Tensor k_pe," \
+          "    Tensor! kv_cache," \
+          "    Tensor! q_out," \
+          "    Tensor slot_mapping," \
+          "    Tensor k_scale," \
+          "    Tensor q_scale," \
+          "    Tensor positions," \
+          "    Tensor cos_cache," \
+          "    Tensor sin_cache," \
+          "    bool is_neox," \
+          "    bool is_nope_first) -> ()", \
+          py::arg("q_nope"), \
+          py::arg("q_pe"), \
+          py::arg("kv_c"), \
+          py::arg("k_pe"), \
+          py::arg("kv_cache"), \
+          py::arg("q_out"), \
+          py::arg("slot_mapping"), \
+          py::arg("k_scale"), \
+          py::arg("q_scale"), \
+          py::arg("positions"), \
+          py::arg("cos_cache"), \
+          py::arg("sin_cache"), \
+          py::arg("is_neox"), \
+          py::arg("is_nope_first"));

 #define CUSTOM_ALL_REDUCE_PYBIND \
     m.def("init_custom_ar", \
