Skip to content

Commit f745add

Browse files
committed
feat: add flashinfer as cuda kernel.
1 parent 6e5607b commit f745add

File tree

10 files changed

+365
-59
lines changed

10 files changed

+365
-59
lines changed

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class BatchInputBuilder {
8686
#if defined(USE_NPU)
8787
std::vector<int32_t> seq_lens;
8888
std::vector<int32_t> q_seq_lens;
89-
#elif defined(USE_MLU)
89+
#elif defined(USE_MLU) || defined(USE_CUDA)
9090
std::vector<int32_t> seq_lens = {0}; // cu_seq_lens
9191
std::vector<int32_t> q_seq_lens = {0}; // q_cu_seq_len
9292
#endif

xllm/core/kernels/cuda/batch_decode.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,4 +269,53 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
269269
});
270270
}
271271

272+
// Convenience wrapper for batched paged-KV-cache *decode* attention.
// Builds a flashinfer decode plan from the cumulative query sequence lengths
// and the KV-cache geometry, then immediately executes it via
// BatchDecodeWithPagedKVCacheRun with NHD layout (kv_layout_code = 0).
//
// k_cache / v_cache are paged caches; from the indexing below they are read as
// [num_pages, page_size, num_kv_heads, head_dim].
// q_cu_seq_lens has batch_size + 1 entries (cumulative query lengths).
//
// NOTE(review): num_qo_heads is taken from query.size(2) while head_dim_qk is
// taken from query.size(-1). If `query` is 3-D [batch, num_heads, head_dim]
// these are the same dimension (head_dim would be passed as the head count) —
// confirm the expected query rank/layout; the prefill Run path in this file
// reads the head count from dim 1 (q->shape[1]).
void batch_decode(torch::Tensor& float_workspace_buffer,
                  torch::Tensor& int_workspace_buffer,
                  torch::Tensor& page_locked_int_workspace_buffer,
                  const torch::Tensor& query,
                  const torch::Tensor& k_cache,
                  const torch::Tensor& v_cache,
                  const torch::Tensor& q_cu_seq_lens,
                  const torch::Tensor& paged_kv_indptr,
                  const torch::Tensor& paged_kv_indices,
                  const torch::Tensor& paged_kv_last_page_len,
                  int64_t window_size_left,
                  torch::Tensor& output,
                  std::optional<torch::Tensor>& output_lse,
                  bool enable_cuda_graph,
                  bool enable_pdl) {
  // cu_seq_lens has one extra leading entry, so batch = len - 1.
  const int64_t batch_size = q_cu_seq_lens.size(0) - 1;
  Array<int64_t> plan_info_vec = BatchDecodeWithPagedKVCachePlan(
      float_workspace_buffer,
      int_workspace_buffer,
      page_locked_int_workspace_buffer,
      q_cu_seq_lens,
      batch_size,
      query.size(2),            // num_qo_heads
      k_cache.size(2),          // num_kv_heads
      k_cache.size(1),          // page_size
      enable_cuda_graph,
      window_size_left,
      /* logits_soft_cap*/ 0.0, // not used
      query.size(-1),           // head_dim_qk
      v_cache.size(-1),         // head_dim_vo
      torch::Tensor(),          // empty_q_data, not used
      torch::Tensor());         // empty_kv_data, not used

  // Execute the planned decode; plan_info_vec carries workspace offsets.
  BatchDecodeWithPagedKVCacheRun(float_workspace_buffer,
                                 int_workspace_buffer,
                                 plan_info_vec,
                                 query,
                                 k_cache,
                                 v_cache,
                                 paged_kv_indptr,
                                 paged_kv_indices,
                                 paged_kv_last_page_len,
                                 output,
                                 output_lse,
                                 /*kv_layout_code=*/0,
                                 window_size_left,
                                 enable_pdl);
}
320+
272321
} // namespace xllm::kernel::cuda

xllm/core/kernels/cuda/batch_prefill.cpp

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,162 @@ Array<int64_t> BatchPrefillWithKVCachePlan(
119119
return Array(plan_info.ToVector());
120120
}
121121

122+
// Executes a previously planned batched prefill over *ragged* (non-paged)
// K/V tensors. plan_info_vec must come from BatchPrefillWithKVCachePlan; it
// encodes byte offsets into the int/float workspace buffers where the plan
// stored its scheduling metadata (tile indices, indptrs, split-KV scratch).
//
// q is read as [total_rows, num_qo_heads, head_dim_qk]. K/V head count and
// strides are picked per `layout` (kNHD: heads on dim 1; otherwise heads on
// dim 0). maybe_lse, when present, must be [q_rows, num_qo_heads].
// mask_mode_code / window_left select masking and sliding-window behavior.
// Launches on the stream associated with float_workspace_buffer's device.
void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer,
                                      TensorView int_workspace_buffer,
                                      Array<int64_t> plan_info_vec,
                                      TensorView q,
                                      TensorView k,
                                      TensorView v,
                                      TensorView qo_indptr,
                                      TensorView kv_indptr,
                                      TensorView o,
                                      Optional<TensorView> maybe_lse,
                                      int64_t mask_mode_code,
                                      int64_t layout,
                                      int64_t window_left,
                                      bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
  // Rehydrate the plan metadata serialized by the Plan step.
  PrefillPlanInfo plan_info;
  plan_info.FromVector(
      std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
  QKVLayout kv_layout = static_cast<QKVLayout>(layout);

  int64_t num_qo_heads = q->shape[1];
  int64_t head_dim_qk = q->shape[2];
  int64_t num_kv_heads =
      (kv_layout == QKVLayout::kNHD) ? k->shape[1] : k->shape[0];
  // Token (n) and head (h) strides; swapped for the non-NHD (HND) layout.
  uint32_t q_stride_n = q->strides[0], q_stride_h = q->strides[1], k_stride_n,
      k_stride_h, v_stride_n, v_stride_h;
  if (kv_layout == QKVLayout::kNHD) {
    k_stride_n = k->strides[0];
    k_stride_h = k->strides[1];
    v_stride_n = v->strides[0];
    v_stride_h = v->strides[1];
  } else {
    k_stride_h = k->strides[0];
    k_stride_n = k->strides[1];
    v_stride_h = v->strides[0];
    v_stride_n = v->strides[1];
  }

  // LSE output, if requested, must match q's [rows, heads] leading dims.
  if (maybe_lse.has_value()) {
    const auto& lse = *maybe_lse;
    TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]);
    TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]);
  }

  void* float_buffer_ptr = float_workspace_buffer->data;
  void* int_buffer_ptr = int_workspace_buffer->data;

  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  cudaSetDevice(float_workspace_buffer->device.device_id);
  const cudaStream_t stream = get_stream(float_workspace_buffer->device);

  // DISPATCH_context instantiates the lambda for the concrete dtypes /
  // head dims / feature flags; the lambda returns true on success.
  DISPATCH_context(
      DTypeQ,
      DTypeKV,
      DTypeO,
      IdType,
      MASK_MODE,
      HEAD_DIM_QK,
      HEAD_DIM_VO,
      POS_ENCODING_MODE,
      USE_SLIDING_WINDOW,
      USE_LOGITS_SOFT_CAP,
      USE_FP16_QK_REDUCTION,
      AttentionVariant,
      RaggedParams,
      PagedParams,
      [&] {
        RaggedParams params;

        params.q = static_cast<DTypeQ*>(q->data);
        params.k = static_cast<DTypeKV*>(k->data);
        params.v = static_cast<DTypeKV*>(v->data);
        params.o = static_cast<DTypeO*>(o->data);
        params.lse = maybe_lse.has_value()
                         ? static_cast<float*>(maybe_lse.value()->data)
                         : nullptr;
        params.q_indptr = static_cast<IdType*>(qo_indptr->data);
        params.kv_indptr = static_cast<IdType*>(kv_indptr->data);
        params.num_qo_heads = num_qo_heads;
        params.num_kv_heads = num_kv_heads;
        // GQA group size as a fast-division constant.
        params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
        params.q_stride_n = q_stride_n;
        params.q_stride_h = q_stride_h;
        params.k_stride_n = k_stride_n;
        params.k_stride_h = k_stride_h;
        params.v_stride_n = v_stride_n;
        params.v_stride_h = v_stride_h;
        params.window_left = window_left;

        // Scheduling pointers default to null; filled in from the plan below.
        params.request_indices = nullptr;
        params.qo_tile_indices = nullptr;
        params.kv_tile_indices = nullptr;
        params.merge_indptr = nullptr;
        params.o_indptr = nullptr;
        params.kv_chunk_size_ptr = nullptr;
        params.block_valid_mask = nullptr;
        params.total_num_rows = nullptr;
        params.max_total_num_rows = 0;
        params.padded_batch_size = 0;
        params.partition_kv = false;

        ADDITIONAL_PARAMS_SETTER

        // Split-KV scratch (partial outputs / scales); stays null unless the
        // plan enabled KV splitting.
        DTypeO* tmp_v = nullptr;
        float* tmp_s = nullptr;

        // Resolve plan-recorded offsets to pointers inside the int workspace.
        params.request_indices = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.request_indices_offset);
        params.qo_tile_indices = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.qo_tile_indices_offset);
        params.kv_tile_indices = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.kv_tile_indices_offset);
        params.o_indptr = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.o_indptr_offset);
        params.kv_chunk_size_ptr = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset);
        if (plan_info.split_kv) {
          params.merge_indptr = GetPtrFromBaseOffset<IdType>(
              int_buffer_ptr, plan_info.merge_indptr_offset);
          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr,
                                               plan_info.v_offset);
          tmp_s =
              GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.s_offset);
          if (plan_info.enable_cuda_graph) {
            params.block_valid_mask = GetPtrFromBaseOffset<bool>(
                int_buffer_ptr, plan_info.block_valid_mask_offset);
          }
        }
        params.padded_batch_size = plan_info.padded_batch_size;
        params.max_total_num_rows = plan_info.total_num_rows;
        if (plan_info.enable_cuda_graph) {
          // Under CUDA graphs the true row count lives in device memory.
          params.total_num_rows = GetPtrFromBaseOffset<uint32_t>(
              int_buffer_ptr, plan_info.total_num_rows_offset);
        }

        cudaError_t status = cudaSuccess;

        DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
          status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
              CTA_TILE_Q,
              HEAD_DIM_QK,
              HEAD_DIM_VO,
              POS_ENCODING_MODE,
              /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION,
              MASK_MODE,
              AttentionVariant,
              RaggedParams>(params, tmp_v, tmp_s, enable_pdl, stream);
        });

        TVM_FFI_ICHECK(status == cudaSuccess)
            << "BatchPrefillWithRaggedKVCache failed with error "
            << cudaGetErrorString(status);
        return true;
      });
}
277+
122278
void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer,
123279
TensorView int_workspace_buffer,
124280
Array<int64_t> plan_info_vec,
@@ -287,4 +443,58 @@ void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer,
287443
});
288444
}
289445

446+
// Convenience wrapper for batched ragged (non-paged) *prefill* attention:
// builds a flashinfer plan and immediately runs it with causal masking and
// NHD layout.
//
// query/key/value are ragged tensors flattened across the batch,
// [total_tokens, num_heads, head_dim], delimited by q_cu_seq_lens /
// kv_cu_seq_lens (cumulative lengths, batch_size + 1 entries each).
// output receives the attention result; output_lse, if provided, the
// log-sum-exp per (row, head). window_size_left < 0 disables the sliding
// window; enable_pdl toggles programmatic dependent launch.
//
// Fix vs. previous revision: the Plan call had the head-count and head-dim
// arguments swapped. BatchPrefillWithRaggedKVCacheRun (which consumes this
// plan) reads num_qo_heads = q->shape[1] and head_dim_qk = q->shape[2] for
// NHD layout, so the plan must be built from dim 1 (heads) and dim -1
// (head_dim); head_dim_vo comes from `value`, not `query`.
void batch_prefill(torch::Tensor& float_workspace_buffer,
                   torch::Tensor& int_workspace_buffer,
                   torch::Tensor& page_locked_int_workspace_buffer,
                   const torch::Tensor& query,
                   const torch::Tensor& key,
                   const torch::Tensor& value,
                   const torch::Tensor& q_cu_seq_lens,
                   const torch::Tensor& kv_cu_seq_lens,
                   int64_t window_size_left,
                   torch::Tensor& output,
                   std::optional<torch::Tensor>& output_lse,
                   bool enable_cuda_graph,
                   bool enable_pdl) {
  // Per-request KV lengths: kv_cu_seq_lens[i + 1] - kv_cu_seq_lens[i].
  torch::Tensor kv_len_arr =
      kv_cu_seq_lens.slice(0, 1) - kv_cu_seq_lens.slice(0, 0, -1);
  // Last cumulative entry = total number of query rows across the batch.
  const int64_t total_num_rows = q_cu_seq_lens[-1].item<int64_t>();
  const int64_t batch_size = q_cu_seq_lens.size(0) - 1;

  Array<int64_t> plan_info_vec = BatchPrefillWithKVCachePlan(
      float_workspace_buffer,
      int_workspace_buffer,
      page_locked_int_workspace_buffer,
      q_cu_seq_lens,
      kv_cu_seq_lens,
      kv_len_arr,
      total_num_rows,
      batch_size,
      query.size(1),   // num_qo_heads ([tokens, heads, head_dim], dim 1)
      key.size(1),     // num_kv_heads
      /*page_size=*/1, // ragged KV: no paging
      enable_cuda_graph,
      query.size(-1),  // head_dim_qk
      value.size(-1),  // head_dim_vo
      /*causal=*/true,
      window_size_left,
      /*fixed_split_size=*/-1,
      /*disable_split_kv=*/false);

  BatchPrefillWithRaggedKVCacheRun(float_workspace_buffer,
                                   int_workspace_buffer,
                                   plan_info_vec,
                                   query,
                                   key,
                                   value,
                                   q_cu_seq_lens,
                                   kv_cu_seq_lens,
                                   output,
                                   output_lse,
                                   /*mask_mode_code=CAUSAL*/ 1,
                                   /*layout=*/0, // kNHD
                                   window_size_left,
                                   enable_pdl);
}
499+
290500
} // namespace xllm::kernel::cuda

xllm/core/kernels/cuda/cuda_ops_api.h

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,35 @@ void reshape_paged_cache(
4141
torch::Tensor& key_cache, // [n_blocks, block_size, n_heads, head_dim]
4242
torch::Tensor& value_cache);
4343

44-
void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer,
45-
TensorView int_workspace_buffer,
46-
Array<int64_t> plan_info_vec,
47-
TensorView q,
48-
TensorView paged_k_cache,
49-
TensorView paged_v_cache,
50-
TensorView qo_indptr,
51-
TensorView paged_kv_indptr,
52-
TensorView paged_kv_indices,
53-
TensorView paged_kv_last_page_len,
54-
TensorView o,
55-
Optional<TensorView> maybe_lse,
56-
int64_t mask_mode_code,
57-
int64_t layout,
58-
int64_t window_left,
59-
bool enable_pdl ADDITIONAL_FUNC_PARAMS);
44+
void batch_prefill(torch::Tensor& float_workspace_buffer,
45+
torch::Tensor& int_workspace_buffer,
46+
torch::Tensor& page_locked_int_workspace_buffer,
47+
const torch::Tensor& query,
48+
const torch::Tensor& key,
49+
const torch::Tensor& value,
50+
const torch::Tensor& q_cu_seq_lens,
51+
const torch::Tensor& kv_cu_seq_lens,
52+
int64_t window_size_left,
53+
torch::Tensor& output,
54+
std::optional<torch::Tensor>& output_lse,
55+
bool enable_cuda_graph,
56+
bool enable_pdl);
6057

61-
void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
62-
TensorView int_workspace_buffer,
63-
Array<int64_t> plan_info_vec,
64-
TensorView q,
65-
TensorView paged_k_cache,
66-
TensorView paged_v_cache,
67-
TensorView paged_kv_indptr,
68-
TensorView paged_kv_indices,
69-
TensorView paged_kv_last_page_len,
70-
TensorView o,
71-
Optional<TensorView> maybe_lse,
72-
int64_t kv_layout_code,
73-
int64_t window_left,
74-
bool enable_pdl ADDITIONAL_FUNC_PARAMS);
58+
void batch_decode(torch::Tensor& float_workspace_buffer,
59+
torch::Tensor& int_workspace_buffer,
60+
torch::Tensor& page_locked_int_workspace_buffer,
61+
const torch::Tensor& query,
62+
const torch::Tensor& k_cache,
63+
const torch::Tensor& v_cache,
64+
const torch::Tensor& q_cu_seq_lens,
65+
const torch::Tensor& paged_kv_indptr,
66+
const torch::Tensor& paged_kv_indices,
67+
const torch::Tensor& paged_kv_last_page_len,
68+
int64_t window_size_left,
69+
torch::Tensor& output,
70+
std::optional<torch::Tensor>& output_lse,
71+
bool enable_cuda_graph,
72+
bool enable_pdl);
7573

7674
void rmsnorm(TensorView output,
7775
TensorView input,

0 commit comments

Comments
 (0)