csrc/batch_prefill.cu (9 changes: 7 additions & 2 deletions)

@@ -76,6 +76,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                       int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
                                       bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
@@ -128,6 +130,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
   params.kv_indptr = static_cast<IdType*>(kv_indptr.data_ptr());
+  params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
+  params.maybe_k_rope_offset = maybe_k_rope_offset ? static_cast<IdType*>(maybe_k_rope_offset->data_ptr()) : nullptr;
   params.num_qo_heads = num_qo_heads;
   params.num_kv_heads = num_kv_heads;
   params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
@@ -202,7 +206,8 @@ void BatchPrefillWithPagedKVCacheRun(
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -263,8 +268,8 @@ void BatchPrefillWithPagedKVCacheRun(
   params.paged_kv = paged_kv;
   params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
   params.o = static_cast<DTypeO*>(o.data_ptr());
-
   params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
+  params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
   params.num_qo_heads = num_qo_heads;
   params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
   params.q_stride_n = q_stride_n;
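In both `Run` functions the new optionals are unwrapped with the same ternary: a present tensor yields its raw device pointer, an absent one yields `nullptr`, which is how the kernel-side params signal "no offset supplied". A minimal sketch of that pattern as a standalone helper (`optional_data_ptr` is a hypothetical name, not something this diff adds):

```cpp
#include <optional>

#include <ATen/core/Tensor.h>

// Hypothetical helper (not in this diff) capturing the unwrap pattern above:
// an absent optional tensor collapses to nullptr, so kernel-side code can
// branch on "no rope offset provided" with a single pointer check.
template <typename T>
T* optional_data_ptr(const std::optional<at::Tensor>& maybe_t) {
  return maybe_t ? static_cast<T*>(maybe_t->data_ptr()) : nullptr;
}

// Usage sketch, mirroring the assignments above:
//   params.maybe_q_rope_offset = optional_data_ptr<IdType>(maybe_q_rope_offset);
//   params.maybe_k_rope_offset = optional_data_ptr<IdType>(maybe_k_rope_offset);
```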
csrc/batch_prefill_jit_pybind.cu (55 changes: 52 additions & 3 deletions)

@@ -29,20 +29,69 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                       int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
                                       bool enable_pdl ADDITIONAL_FUNC_PARAMS);

 void BatchPrefillWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl ADDITIONAL_FUNC_PARAMS);

 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Batch-request prefill attention with KV-Cache plan
   m.def("plan", BatchPrefillWithKVCachePlan);
   // Batch-request prefill attention with KV-Cache operator
-  m.def("ragged_run", BatchPrefillWithRaggedKVCacheRun);
+  m.def("ragged_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor k,
+           at::Tensor v,
+           at::Tensor qo_indptr,
+           at::Tensor kv_indptr,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           std::optional<at::Tensor> maybe_k_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithRaggedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
+              qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
+        });
   // Batch-request prefill attention with KV-Cache operator
-  m.def("paged_run", BatchPrefillWithPagedKVCacheRun);
+  m.def("paged_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor paged_k_cache,
+           at::Tensor paged_v_cache,
+           at::Tensor qo_indptr,
+           at::Tensor paged_kv_indptr,
+           at::Tensor paged_kv_indices,
+           at::Tensor paged_kv_last_page_len,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithPagedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
+              paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
+              paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, enable_pdl, additional_params);
+        });
 }
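Both registrations switch from a bare function pointer to a lambda that spells out the complete argument list, so the new `std::optional<at::Tensor>` rope-offset slots appear explicitly in the inferred operator schema. A minimal sketch of the same registration pattern, with placeholder names (`my_ext`, `echo`) that are not part of FlashInfer:

```cpp
#include <optional>

#include <ATen/ATen.h>
#include <torch/library.h>

// Registering through a lambda makes every argument (and its optionality)
// explicit to schema inference, instead of leaning on a free function whose
// trailing parameters are injected by a macro such as ADDITIONAL_FUNC_PARAMS.
TORCH_LIBRARY_FRAGMENT(my_ext, m) {  // "my_ext" is a placeholder namespace
  m.def("echo", [](at::Tensor x, std::optional<at::Tensor> maybe_offset) {
    // The optional tensor behaves like the rope offsets above: applied only
    // when the caller actually passed one.
    return maybe_offset.has_value() ? x + *maybe_offset : x;
  });
}
```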
csrc/flashinfer_ops.cu (59 changes: 54 additions & 5 deletions)

@@ -112,15 +112,18 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor q, at::Tensor k, at::Tensor v,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
-                                      int64_t layout,
-                                      int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+                                      int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
+                                      bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);

 void BatchPrefillWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);

 //========== pod-attention =========
 void pod_with_kv_cache_tensor(
@@ -275,8 +278,54 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Single-request prefill attention with KV-Cache operator
   m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache);
   m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan);
-  m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun);
-  m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun);
+  m.def("batch_prefill_with_ragged_kv_cache_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor k,
+           at::Tensor v,
+           at::Tensor qo_indptr,
+           at::Tensor kv_indptr,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           std::optional<at::Tensor> maybe_k_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithRaggedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
+              qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
+        });
+  m.def("batch_prefill_with_paged_kv_cache_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor paged_k_cache,
+           at::Tensor paged_v_cache,
+           at::Tensor qo_indptr,
+           at::Tensor paged_kv_indptr,
+           at::Tensor paged_kv_indices,
+           at::Tensor paged_kv_last_page_len,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithPagedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
+              paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
+              paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, enable_pdl, additional_params);
+        });

   // pod-attention
   // Temporarily disabled because we don't generate the implementation yet.
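Neither declaration shows how the kernels consume the offsets, but the plumbing (one `IdType` array per side, `nullptr` when absent) is consistent with shifting each token's rotary position by a caller-supplied base. A hedged sketch of that reading, where the function name and the per-request indexing are assumptions rather than code from this diff:

```cpp
#include <cstdint>

// Assumed semantics (sketch only, not FlashInfer kernel code): with an offset
// array, token i of request r is rotated at position offset[r] + i; with a
// nullptr, rotary positions start at 0 as before.
inline int64_t rope_position(const int32_t* maybe_rope_offset,  // nullptr => no offset
                             int64_t request_idx, int64_t token_idx) {
  const int64_t base = maybe_rope_offset ? maybe_rope_offset[request_idx] : 0;
  return base + token_idx;
}
```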
flashinfer/attention.py (2 changes: 2 additions & 0 deletions)

@@ -138,6 +138,7 @@ def run(
         lse: Optional[torch.Tensor] = None,
         logits_soft_cap: float = 0.0,
         profiler_buffer: Optional[torch.Tensor] = None,
+        q_rope_offset: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if profiler_buffer is None:
             if self._use_profiler:
@@ -180,6 +181,7 @@ def run(
                 self._page_size,
                 self._sm_scale,
                 logits_soft_cap,
+                q_rope_offset,
                 # ADDITIONAL_FUNC_PARAMS
                 # PROFILER_FUNC_PARAMS
                 *profiler_args,