11 changes: 6 additions & 5 deletions csrc/batch_attention.cu
@@ -68,8 +68,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t page_size, double sm_scale,
double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
int64_t page_size, double sm_scale, double logits_soft_cap,
int64_t window_left ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
HolisticPlanInfo<2> plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));

@@ -111,7 +111,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
AttentionVariant, PersistentParams, [&] {
PersistentParams params[2];

IdType* len_kv_chunk =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.len_kv_chunk_offset);
for (int i = 0; i < 2; i++) {
params[i].q = static_cast<DTypeQ*>(q.data_ptr());
params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
@@ -138,8 +139,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_head_idx_offset);
params[i].work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
params[i].len_kv_chunk =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].len_kv_chunk_offset);
params[i].len_kv_chunk = len_kv_chunk + i;

params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
params[i].final_lse =
@@ -172,6 +172,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo

params[i].sm_scale = sm_scale;
params[i].logits_soft_cap = logits_soft_cap;
params[i].window_left = window_left;
// NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
// will be problematic because of the params[i]
ADDITIONAL_PARAMS_SETTER
29 changes: 22 additions & 7 deletions csrc/batch_attention_customize_config.jinja
@@ -7,7 +7,6 @@
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/profiler.cuh>

using namespace flashinfer;

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
@@ -22,29 +21,44 @@ using namespace flashinfer;

{{ variant_decl }}

template <bool UseLogitsSoftCap>
template <bool use_sliding_window, bool use_logits_soft_cap>
struct StandardAttention : AttentionVariantBase {
float sm_scale_log2;
float soft_cap_pre_tanh_scale;
static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
static constexpr bool UseLogitsSoftCap = use_logits_soft_cap;
static constexpr bool UseSlidingWindow = use_sliding_window;

uint32_t window_left;
float sm_scale_log2, soft_cap_pre_tanh_scale;
PROFILER_CLOSURE_PARAMS_DECL

template <typename Params>
__device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
uint8_t* smem_ptr) {
if constexpr (UseLogitsSoftCap) {
if constexpr (use_logits_soft_cap) {
soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
sm_scale_log2 = math::log2e * params.logits_soft_cap;
}else{
sm_scale_log2 = params.sm_scale * math::log2e;
}

}
REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
if constexpr (UseLogitsSoftCap) {
if constexpr (use_logits_soft_cap) {
logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
}
return logits;
})

REGISTER_LOGITS_MASK(params, work_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
bool mask = true;
// TODO: to support custom mask (only used by Spec decoding in SGL), must register request_indices in plan info
if constexpr (use_sliding_window) {
uint32_t qo_len = params.q_len[work_idx];
uint32_t kv_len = params.kv_len[work_idx];
window_left = (params.window_left >= 0) ? params.window_left : kv_len;
mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx);
}
return mask;
Contributor review comment on lines +55 to +60 (severity: medium):

The member variable window_left is only used within this mask generation logic. It's better to declare it as a local const variable to improve code clarity and prevent accidental misuse elsewhere. This also removes the need for the member declaration at line 30.

    if constexpr (use_sliding_window) {
      uint32_t qo_len = params.q_len[work_idx];
      uint32_t kv_len = params.kv_len[work_idx];
      const auto window_left_val = (params.window_left >= 0) ? static_cast<uint32_t>(params.window_left) : kv_len;
      mask &= (kv_idx + qo_len + window_left_val >= kv_len + qo_idx);
    }

})
};

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \
@@ -110,6 +124,7 @@ struct PersistentParams {

float sm_scale;
double logits_soft_cap;
int64_t window_left;
{{ additional_params_decl }}

PROFILER_PARAMS_DECL
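To make the new sliding-window predicate in REGISTER_LOGITS_MASK above concrete, here is a small host-side reference in plain Python. It is a sanity-check sketch, not part of the PR: the causal part is handled elsewhere in the kernel (via the mask mode and mask_tile_idx) and is included here only so the worked example is complete; the variable names mirror the kernel code, everything else is illustrative.

    def visible(qo_idx, kv_idx, qo_len, kv_len, window_left, causal=True):
        # Reference for the mask applied per (query, key) position of one request.
        ok = True
        if causal:
            # The last qo_len queries are aligned to the end of the kv sequence.
            ok = ok and kv_idx <= kv_len - qo_len + qo_idx
        if window_left >= 0:
            # Same inequality as REGISTER_LOGITS_MASK:
            # kv_idx + qo_len + window_left >= kv_len + qo_idx
            ok = ok and kv_idx + qo_len + window_left >= kv_len + qo_idx
        return ok

    # Example: 4 new queries on top of 8 cached tokens (kv_len = 12), window_left = 2.
    # The last query (qo_idx = 3) sees kv positions [9, 10, 11]: itself plus two to the left.
    print([kv for kv in range(12) if visible(3, kv, qo_len=4, kv_len=12, window_left=2)])
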
4 changes: 2 additions & 2 deletions csrc/batch_attention_jit_pybind.cu
@@ -28,8 +28,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t page_size, double sm_scale,
double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
int64_t page_size, double sm_scale, double logits_soft_cap,
int64_t window_left ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchPagedAttentionPlan);
57 changes: 55 additions & 2 deletions flashinfer/attention.py
@@ -74,14 +74,57 @@ def plan(
page_size: int,
causal: bool = False,
sm_scale: float = None,
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
q_data_type: torch.dtype = torch.bfloat16,
kv_data_type: torch.dtype = torch.bfloat16,
use_profiler: bool = False,
) -> None:
"""Plan batch persistent attention on Paged KV-Cache with mixed prefill and decode batch.
Parameters
----------
qo_indptr : torch.Tensor
The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
kv_indptr : torch.Tensor
The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
kv_indices : torch.Tensor
The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``.
kv_len_arr : torch.Tensor
The kv length of each request, shape: ``[batch_size]``. Will be used in place of last_page_len.
num_qo_heads : int
The number of query/output heads.
num_kv_heads : int
The number of key/value heads.
head_dim_qk : int
The dimension of the query/key heads.
head_dim_vo : int
The dimension of the value/output heads.
page_size : int
The size of each page in the paged kv-cache.
causal : bool
Whether to apply causal mask to the attention matrix.
sm_scale : float
The scale used in softmax, if not provided, will be set to
``1.0 / sqrt(head_dim)``.
window_left : int
The left (inclusive) window size for the attention window. When set to ``-1``, the window
size is the full length of the sequence. Defaults to ``-1``.
logits_soft_cap : Optional[float]
The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
provided, will be set to ``0``. If greater than 0, the logits will be capped according to
formula:
:math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
where :math:`x` is the input logits.
q_data_type : Union[str, torch.dtype]
The data type of the query tensor, defaults to ``torch.bfloat16``.
kv_data_type : Optional[Union[str, torch.dtype]]
The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
use_profiler : bool
Whether to use the CTA-level profiler, defaults to ``False``.
"""

if logits_soft_cap is None:
logits_soft_cap = 0.0
self._logits_soft_cap = logits_soft_cap

# get jit module
get_module_args = (
@@ -93,6 +136,7 @@ def plan(
head_dim_vo,
PosEncodingMode["NONE"].value,
logits_soft_cap > 0.0,
window_left >= 0, # use_sliding_window
use_profiler, # different compiler path
)
self.module = get_holistic_attention_module(*get_module_args)
@@ -111,6 +155,8 @@ def plan(
self._page_size = page_size
self._sm_scale = sm_scale
self._use_profiler = use_profiler
self._logits_soft_cap = logits_soft_cap
self._window_left = window_left

# No additional buffer allocated for CUDA graph tensor
# Allocate outside FlashInfer
@@ -137,6 +183,7 @@ def run(
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
logits_soft_cap: float = 0.0,
window_left: int = -1,
profiler_buffer: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if profiler_buffer is None:
@@ -146,7 +193,11 @@ def run(
)
if logits_soft_cap > 0.0 and self._logits_soft_cap <= 0.0:
raise ValueError(
"logits_soft_cap used in kernel run but not provided in plan(). This will cause template deduction error."
"logits_soft_cap used in kernel run but not provided in plan(). This will cause the wrong template used."
)
if window_left >= 0 and self._window_left < 0:
raise ValueError(
"Sliding window attention used in kernel run but not provided in plan(). This will cause the wrong template used."
)

k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)
@@ -162,6 +213,7 @@ def run(

# profiler_buffer is optional
profiler_args = (profiler_buffer,) if self._use_profiler else ()
window_left = window_left if window_left is not None else self._window_left

self.module.run(
self.float_workspace_buffer,
@@ -180,6 +232,7 @@ def run(
self._page_size,
self._sm_scale,
logits_soft_cap,
window_left,
# ADDITIONAL_FUNC_PARAMS
# PROFILER_FUNC_PARAMS
*profiler_args,
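For context, a minimal usage sketch of the window_left plumbing added to plan()/run(). Only the keyword names shown in this diff are taken from the code; the wrapper class name (BatchAttention), its constructor arguments, and the input tensors are assumptions and would need to match the actual module.

    import torch
    import flashinfer

    # Assumed wrapper exposing the plan()/run() methods in this file; the constructor
    # arguments are illustrative. qo_indptr, kv_indptr, kv_indices, kv_len_arr, q, and
    # kv_cache are prepared by the caller as described in the docstring above.
    wrapper = flashinfer.BatchAttention(kv_layout="NHD")

    window_left = 127  # each query attends to itself plus 127 tokens to its left

    wrapper.plan(
        qo_indptr, kv_indptr, kv_indices, kv_len_arr,
        num_qo_heads=32, num_kv_heads=8,
        head_dim_qk=128, head_dim_vo=128,
        page_size=16,
        causal=True,
        window_left=window_left,   # must be >= 0 at plan time to select the SWA template
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )

    # run() must stay consistent with plan(): passing window_left >= 0 here after
    # planning with window_left == -1 raises the ValueError added in this PR.
    out, lse = wrapper.run(q, kv_cache, window_left=window_left)
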
9 changes: 8 additions & 1 deletion flashinfer/jit/attention/pytorch.py
@@ -395,6 +395,7 @@ def get_batch_attention_uri(
head_dim_qk: int,
head_dim_vo: int,
pos_encoding_mode: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_profiler: bool,
) -> str:
@@ -406,6 +407,7 @@ def get_batch_attention_uri(
f"head_dim_qk_{head_dim_qk}_"
f"head_dim_vo_{head_dim_vo}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_"
f"use_profiler_{str(use_profiler).lower()}"
)
@@ -864,6 +866,7 @@ def gen_batch_attention_module(
head_dim_vo: int,
pos_encoding_mode: int,
use_logits_soft_cap: bool,
use_sliding_window: bool,
use_profiler: bool,
):
uri = get_batch_attention_uri(
@@ -875,14 +878,15 @@ def gen_batch_attention_module(
head_dim_vo,
pos_encoding_mode,
use_logits_soft_cap,
use_sliding_window,
use_profiler,
)

additional_tensor_names = []
additional_tensor_dtypes = []
additional_scalar_names = []
additional_scalar_dtypes = []
variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>"
variant_name = f"StandardAttention<{str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}>"
variant_decl = f"#include<flashinfer/attention/variants.cuh>"

return gen_customize_batch_attention_module(
@@ -901,6 +905,7 @@ def gen_batch_attention_module(
variant_decl,
pos_encoding_mode=pos_encoding_mode,
use_logits_soft_cap=use_logits_soft_cap,
use_sliding_window=use_sliding_window,
use_profiler=use_profiler,
)

@@ -1513,6 +1518,7 @@ def gen_customize_batch_attention_module(
variant_decl: str,
pos_encoding_mode: int = 0,
use_logits_soft_cap: bool = False,
use_sliding_window: bool = False,
use_profiler: bool = False,
):
kwargs = {
@@ -1526,6 +1532,7 @@ def gen_customize_batch_attention_module(
"head_dim_vo": head_dim_vo,
"pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode],
"use_logits_soft_cap": str(use_logits_soft_cap).lower(),
"use_sliding_window": str(use_sliding_window).lower(),
}
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
(additional_params_decl, additional_func_params, additional_params_setter) = (
50 changes: 39 additions & 11 deletions include/flashinfer/attention/persistent.cuh
@@ -35,7 +35,7 @@ __device__ __forceinline__ auto get_block_coord(const Params& params, const uint
params.partial_indptr[work_idx], params.q_len[work_idx],
params.kv_len[work_idx], params.q_start[work_idx], params.kv_start[work_idx],
params.kv_end[work_idx], params.kv_head_idx_arr[work_idx],
params.len_kv_chunk[work_idx]);
*params.len_kv_chunk);
}

template <typename KTraits>
Expand Down Expand Up @@ -269,7 +269,26 @@ struct BlockBatchPagedAttentionPersistent {
: kv_end) /
CTA_TILE_KV -
(kv_start / CTA_TILE_KV);

// int window_tile_idx =
// (CAUSAL ? min(kv_end, kv_len - q_len + ceil_div(packed_qo_start, gqa_group_size))
// : kv_end) /
// CTA_TILE_KV -
// (q_len + params.window_left + kv_start) / CTA_TILE_KV;
// window_tile_idx = params.window_left > 0 ? window_tile_idx : 0;
// if ( blockIdx.x == 0 && blockIdx.y == 0) {
// printf("kv_tile_idx: %d, mask_tile_idx: %d, window_tile_idx: %d\n", kv_tile_idx,
// mask_tile_idx, window_tile_idx);
// }
int window_tile_idx = 0;
if (params.window_left > 0) {
window_tile_idx =
ceil_div(
min(kv_end, kv_len + ceil_div(packed_qo_start + cluster_tile_q, gqa_group_size)),
CTA_TILE_KV) -
1 - (q_len + params.window_left + kv_start) / CTA_TILE_KV;
}
auto mask_tile_idx_ = mask_tile_idx, window_tile_idx_ = window_tile_idx,
kv_tile_idx_ = kv_tile_idx;
uint32_t block_iter_base = kv_indptr * block_size + kv_start;
// last kv tile
__syncthreads();
@@ -289,7 +308,8 @@ struct BlockBatchPagedAttentionPersistent {

// loop with mask
LOOP_SPLIT_MASK(
kv_tile_idx, kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0,
kv_tile_idx,
(kv_tile_idx >= mask_tile_idx || kv_tile_idx < window_tile_idx) && kv_tile_idx > 0,
kv_tile_idx + 1 > NUM_STAGES, {
prefetch_offest<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h,
@@ -298,16 +318,16 @@ struct BlockBatchPagedAttentionPersistent {
__syncthreads();

compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
if constexpr (AttentionVariant::use_logits_soft_cap) {
if constexpr (AttentionVariant::UseLogitsSoftCap) {
logits_transform<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
NUM_MMA_KV * 16,
q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
}
if constexpr (WITH_MASK) {
logits_mask<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
logits_mask<KTraits>( // work_idx is for window_left
params, variant, /*work_idx=*/work_idx, qo_packed_idx_base,
kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
NUM_MMA_KV * 16,
q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
@@ -336,22 +356,30 @@ struct BlockBatchPagedAttentionPersistent {
#pragma unroll
for (; kv_tile_idx >= 0; --kv_tile_idx) {
compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
if constexpr (AttentionVariant::use_logits_soft_cap) {
if constexpr (AttentionVariant::UseLogitsSoftCap) {
logits_transform<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start +
(kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
}
// if (kv_tile_idx < window_tile_idx) { // TODO: check why adding this leads to nan
logits_mask<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
params, variant, /*work_idx=*/work_idx, qo_packed_idx_base,
kv_start +
(kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
// }
// update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
// compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
}
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
printf(
"kv_start:%d, kv_end:%d, kv_tile_idx: %d, mask_tile_idx: %d, window_tile_idx: %d, "
"o_frag: %f\n",
kv_start, kv_end, kv_tile_idx, mask_tile_idx_, window_tile_idx,
o_frag[NUM_MMA_Q - 1][NUM_MMA_KV - 1][7]);
}

__syncthreads();

finalize_m<KTraits>(variant, m);
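To illustrate what the new masked-loop split above is aiming at, the following standalone Python sketch classifies KV tiles for one CTA tile of queries using a per-position predicate (causal plus sliding window): tiles entirely outside the bounds, tiles that need masking, and tiles that are fully visible. It mirrors only the intent of mask_tile_idx / window_tile_idx; it is not a transcription of the in-progress index arithmetic above, and the tile size and sequence lengths are illustrative.

    def visible(q, kv, qo_len, kv_len, window_left):
        causal_ok = kv <= kv_len - qo_len + q
        window_ok = window_left < 0 or kv + qo_len + window_left >= kv_len + q
        return causal_ok and window_ok

    def classify_kv_tiles(qo_rows, qo_len, kv_len, window_left, cta_tile_kv):
        """Label each KV tile as 'full', 'needs_mask', or 'hidden' for the given query rows."""
        labels = []
        num_tiles = (kv_len + cta_tile_kv - 1) // cta_tile_kv
        for t in range(num_tiles):
            cells = [visible(q, kv, qo_len, kv_len, window_left)
                     for q in qo_rows
                     for kv in range(t * cta_tile_kv, min((t + 1) * cta_tile_kv, kv_len))]
            labels.append("full" if all(cells) else "needs_mask" if any(cells) else "hidden")
        return labels

    # 16 queries appended to a 128-token sequence, window_left = 32, KV tiles of 16:
    # tiles far to the left of the window are hidden, the two boundary tiles take the
    # masked path (as in LOOP_SPLIT_MASK above), and the tile in between is fully visible.
    print(classify_kv_tiles(range(16), qo_len=16, kv_len=128, window_left=32, cta_tile_kv=16))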