diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu
index a51fc7f56..469a81c4c 100644
--- a/csrc/batch_prefill.cu
+++ b/csrc/batch_prefill.cu
@@ -76,6 +76,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                       int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
                                       bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
@@ -128,6 +130,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
   params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
   params.kv_indptr = static_cast<IdType*>(kv_indptr.data_ptr());
+  params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
+  params.maybe_k_rope_offset = maybe_k_rope_offset ? static_cast<IdType*>(maybe_k_rope_offset->data_ptr()) : nullptr;
   params.num_qo_heads = num_qo_heads;
   params.num_kv_heads = num_kv_heads;
   params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
@@ -202,7 +206,8 @@ void BatchPrefillWithPagedKVCacheRun(
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -263,7 +268,8 @@ void BatchPrefillWithPagedKVCacheRun(
   params.paged_kv = paged_kv;
   params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
   params.o = static_cast<DTypeO*>(o.data_ptr());
   params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
+  params.maybe_q_rope_offset = maybe_q_rope_offset ? static_cast<IdType*>(maybe_q_rope_offset->data_ptr()) : nullptr;
   params.num_qo_heads = num_qo_heads;
   params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
   params.q_stride_n = q_stride_n;
diff --git a/csrc/batch_prefill_jit_pybind.cu b/csrc/batch_prefill_jit_pybind.cu
index 5421ab1cf..740a9623e 100644
--- a/csrc/batch_prefill_jit_pybind.cu
+++ b/csrc/batch_prefill_jit_pybind.cu
@@ -29,6 +29,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                       int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
                                       bool enable_pdl ADDITIONAL_FUNC_PARAMS);
 
 void BatchPrefillWithPagedKVCacheRun(
@@ -36,13 +38,60 @@ void BatchPrefillWithPagedKVCacheRun(
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl ADDITIONAL_FUNC_PARAMS);
 
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Batch-request prefill attention with KV-Cache plan
   m.def("plan", BatchPrefillWithKVCachePlan);
   // Batch-request prefill attention with KV-Cache operator
-  m.def("ragged_run", BatchPrefillWithRaggedKVCacheRun);
+  m.def("ragged_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor k,
+           at::Tensor v,
+           at::Tensor qo_indptr,
+           at::Tensor kv_indptr,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           std::optional<at::Tensor> maybe_k_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithRaggedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
+              qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
+        });
   // Batch-request prefill attention with KV-Cache operator
-  m.def("paged_run", BatchPrefillWithPagedKVCacheRun);
+  m.def("paged_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor paged_k_cache,
+           at::Tensor paged_v_cache,
+           at::Tensor qo_indptr,
+           at::Tensor paged_kv_indptr,
+           at::Tensor paged_kv_indices,
+           at::Tensor paged_kv_last_page_len,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithPagedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
+              paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
+              paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, enable_pdl, additional_params);
+        });
 }
diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu
index 526ad969a..9877bf656 100644
--- a/csrc/flashinfer_ops.cu
+++ b/csrc/flashinfer_ops.cu
@@ -112,15 +112,18 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor q, at::Tensor k, at::Tensor v,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
-                                      int64_t layout,
-                                      int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+                                      int64_t layout, int64_t window_left,
+                                      std::optional<at::Tensor> maybe_q_rope_offset,
+                                      std::optional<at::Tensor> maybe_k_rope_offset,
+                                      bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
 
 void BatchPrefillWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
     at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
     at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+    int64_t window_left, std::optional<at::Tensor> maybe_q_rope_offset,
+    bool enable_pdl BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
 
 //========== pod-attention =========
 void pod_with_kv_cache_tensor(
@@ -275,8 +278,54 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Single-request prefill attention with KV-Cache operator
   m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache);
   m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan);
-  m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun);
-  m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun);
+  m.def("batch_prefill_with_ragged_kv_cache_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor k,
+           at::Tensor v,
+           at::Tensor qo_indptr,
+           at::Tensor kv_indptr,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           std::optional<at::Tensor> maybe_k_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithRaggedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, k, v,
+              qo_indptr, kv_indptr, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, maybe_k_rope_offset, enable_pdl, additional_params);
+        });
+  m.def("batch_prefill_with_paged_kv_cache_run",
+        [](at::Tensor float_workspace_buffer,
+           at::Tensor int_workspace_buffer,
+           at::Tensor plan_info_vec,
+           at::Tensor q,
+           at::Tensor paged_k_cache,
+           at::Tensor paged_v_cache,
+           at::Tensor qo_indptr,
+           at::Tensor paged_kv_indptr,
+           at::Tensor paged_kv_indices,
+           at::Tensor paged_kv_last_page_len,
+           at::Tensor o,
+           std::optional<at::Tensor> maybe_lse,
+           int64_t mask_mode_code,
+           int64_t layout,
+           int64_t window_left,
+           std::optional<at::Tensor> maybe_q_rope_offset,
+           bool enable_pdl,
+           std::vector<at::Tensor> additional_params) {
+          return BatchPrefillWithPagedKVCacheRun(
+              float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, paged_k_cache,
+              paged_v_cache, qo_indptr, paged_kv_indptr, paged_kv_indices,
+              paged_kv_last_page_len, o, maybe_lse, mask_mode_code, layout, window_left,
+              maybe_q_rope_offset, enable_pdl, additional_params);
+        });
 
   // pod-attention
   // Temporarily disabled because we don't generate the implementation yet.
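For reference, the rewritten bindings change the Python-visible argument order: the new optional RoPE-offset tensors sit between `window_left` and `enable_pdl`, and the paged path takes only the query offset. A hedged sketch of a call site under that assumption follows; `module` stands for the compiled extension handle and the placeholder tensors are illustrative, none of which is defined by this patch.

```python
# Hedged sketch: argument order expected by the rewritten "ragged_run" binding.
# "module" is an assumed handle to the JIT-compiled extension; None is passed
# where the new optional rope-offset tensors are unused.
def call_ragged_run(module, bufs, tensors, scalars, additional_params):
    float_ws, int_ws, plan_info = bufs
    q, k, v, qo_indptr, kv_indptr, o, maybe_lse = tensors
    mask_mode_code, layout, window_left, enable_pdl = scalars
    return module.ragged_run(
        float_ws, int_ws, plan_info,
        q, k, v, qo_indptr, kv_indptr, o, maybe_lse,
        mask_mode_code, layout, window_left,
        None,  # maybe_q_rope_offset: new in this patch
        None,  # maybe_k_rope_offset: new in this patch (ragged path only)
        enable_pdl,
        additional_params,
    )
```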
diff --git a/flashinfer/attention.py b/flashinfer/attention.py
index 224403e86..c6d3ec069 100644
--- a/flashinfer/attention.py
+++ b/flashinfer/attention.py
@@ -138,6 +138,7 @@ def run(
         lse: Optional[torch.Tensor] = None,
         logits_soft_cap: float = 0.0,
         profiler_buffer: Optional[torch.Tensor] = None,
+        q_rope_offset: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if profiler_buffer is None:
             if self._use_profiler:
@@ -180,6 +181,7 @@ def run(
             self._page_size,
             self._sm_scale,
             logits_soft_cap,
+            q_rope_offset,
             # ADDITIONAL_FUNC_PARAMS
             # PROFILER_FUNC_PARAMS
             *profiler_args,
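Semantically, each offset tensor is assumed to hold a per-request starting position that shifts the RoPE index of every token in that request, which is useful when a chunk of a longer sequence is prefilled and positions should not restart at zero. A minimal reference sketch of that assumption, with `rope_positions` being a hypothetical helper rather than part of this patch:

```python
import torch

def rope_positions(indptr: torch.Tensor, rope_offset: torch.Tensor) -> torch.Tensor:
    """Per-token RoPE positions for a ragged batch (reference semantics only).

    indptr:      (batch_size + 1,) request boundaries, as in qo_indptr / kv_indptr.
    rope_offset: (batch_size,) assumed starting RoPE position of each request.
    """
    positions = []
    for b in range(indptr.numel() - 1):
        seq_len = int(indptr[b + 1] - indptr[b])
        # local indices 0..seq_len-1, shifted by this request's offset
        positions.append(rope_offset[b] + torch.arange(seq_len))
    return torch.cat(positions)

# Example: two requests of lengths 3 and 2; the first continues from position 10.
indptr = torch.tensor([0, 3, 5])
offsets = torch.tensor([10, 0])
print(rope_positions(indptr, offsets))  # tensor([10, 11, 12,  0,  1])
```

On the Python side, the new `q_rope_offset` keyword of `run()` is forwarded in the `ADDITIONAL_FUNC_PARAMS` slot and ends up as `params.maybe_q_rope_offset` (a null pointer when omitted), so existing callers that do not pass it are unaffected.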