Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout, int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
std::optional<at::Tensor> maybe_k_rope_offset,
bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
PrefillPlanInfo plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));
Expand Down Expand Up @@ -128,6 +130,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
params.kv_indptr = static_cast<IdType*>(kv_indptr.data_ptr());
params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
params.maybe_k_rope_offset = maybe_k_rope_offset ? static_cast<IdType*>(maybe_k_rope_offset->data_ptr()) : nullptr;
params.num_qo_heads = num_qo_heads;
params.num_kv_heads = num_kv_heads;
params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
Expand Down Expand Up @@ -202,7 +206,8 @@ void BatchPrefillWithPagedKVCacheRun(
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
PrefillPlanInfo plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
Expand Down Expand Up @@ -263,8 +268,8 @@ void BatchPrefillWithPagedKVCacheRun(
params.paged_kv = paged_kv;
params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
params.o = static_cast<DTypeO*>(o.data_ptr());

params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
params.num_qo_heads = num_qo_heads;
params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
params.q_stride_n = q_stride_n;
Expand Down
55 changes: 52 additions & 3 deletions csrc/batch_prefill_jit_pybind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,69 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout, int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
std::optional<at::Tensor> maybe_k_rope_offset,
bool enable_pdl ADDITIONAL_FUNC_PARAMS);

void BatchPrefillWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
bool enable_pdl ADDITIONAL_FUNC_PARAMS);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// Batch-request prefill attention with KV-Cache plan
m.def("plan", BatchPrefillWithKVCachePlan);
// Batch-request prefill attention with KV-Cache operator
m.def("ragged_run", BatchPrefillWithRaggedKVCacheRun);
m.def("ragged_run",
[](at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec,
at::Tensor q,
at::Tensor k,
at::Tensor v,
at::Tensor qo_indptr,
at::Tensor kv_indptr,
at::Tensor o,
std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code,
int64_t layout,
int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
std::optional<at::Tensor> maybe_k_rope_offset,
bool enable_pdl,
std::vector<at::Tensor> additional_params) {
return BatchPrefillWithRaggedKVCacheRun(
float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
});
// Batch-request prefill attention with KV-Cache operator
m.def("paged_run", BatchPrefillWithPagedKVCacheRun);
m.def("paged_run",
[](at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec,
at::Tensor q,
at::Tensor paged_k_cache,
at::Tensor paged_v_cache,
at::Tensor qo_indptr,
at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices,
at::Tensor paged_kv_last_page_len,
at::Tensor o,
std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code,
int64_t layout,
int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
bool enable_pdl,
std::vector<at::Tensor> additional_params) {
return BatchPrefillWithPagedKVCacheRun(
float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
maybe_q_rope_offset, enable_pdl, additional_params);
});
}
59 changes: 54 additions & 5 deletions csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,18 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout,
int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
int64_t layout, int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
std::optional<at::Tensor> maybe_k_rope_offset,
bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);

void BatchPrefillWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);

//========== pod-attention =========
void pod_with_kv_cache_tensor(
Expand Down Expand Up @@ -275,8 +278,54 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// Single-request prefill attention with KV-Cache operator
m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache);
m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan);
m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun);
m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun);
m.def("batch_prefill_with_ragged_kv_cache_run",
[](at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec,
at::Tensor q,
at::Tensor k,
at::Tensor v,
at::Tensor qo_indptr,
at::Tensor kv_indptr,
at::Tensor o,
std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code,
int64_t layout,
int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
std::optional<at::Tensor> maybe_k_rope_offset,
bool enable_pdl,
std::vector<at::Tensor> additional_params) {
return BatchPrefillWithRaggedKVCacheRun(
float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
});
m.def("batch_prefill_with_paged_kv_cache_run",
[](at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec,
at::Tensor q,
at::Tensor paged_k_cache,
at::Tensor paged_v_cache,
at::Tensor qo_indptr,
at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices,
at::Tensor paged_kv_last_page_len,
at::Tensor o,
std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code,
int64_t layout,
int64_t window_left,
std::optional<at::Tensor> maybe_q_rope_offset,
bool enable_pdl,
std::vector<at::Tensor> additional_params) {
return BatchPrefillWithPagedKVCacheRun(
float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
maybe_q_rope_offset, enable_pdl, additional_params);
});

// pod-attention
// Temporarily disabled because we don't generate the implementation yet.
Expand Down
4 changes: 4 additions & 0 deletions flashinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def run(
lse: Optional[torch.Tensor] = None,
logits_soft_cap: float = 0.0,
profiler_buffer: Optional[torch.Tensor] = None,
q_rope_offset: Optional[torch.Tensor] = None,
k_rope_offset: Optional[torch.Tensor] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The k_rope_offset parameter is not supported by the paged attention kernel that BatchAttention wraps. The underlying C++ function BatchPrefillWithPagedKVCacheRun only accepts q_rope_offset. Including k_rope_offset here will lead to a runtime error when calling the kernel. Please remove this parameter.

) -> Tuple[torch.Tensor, torch.Tensor]:
if profiler_buffer is None:
if self._use_profiler:
Expand Down Expand Up @@ -180,6 +182,8 @@ def run(
self._page_size,
self._sm_scale,
logits_soft_cap,
q_rope_offset,
k_rope_offset,
# ADDITIONAL_FUNC_PARAMS
# PROFILER_FUNC_PARAMS
*profiler_args,
Expand Down