
Commit 2aa451e

Commit message: fix
1 parent e3fc005

7 files changed, +102 -21 lines changed

csrc/batch_attention.cu

Lines changed: 3 additions & 2 deletions
@@ -68,8 +68,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
-                            double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
+                            int64_t page_size, double sm_scale, double logits_soft_cap,
+                            int64_t window_left ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));

@@ -172,6 +172,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo

     params[i].sm_scale = sm_scale;
     params[i].logits_soft_cap = logits_soft_cap;
+    params[i].window_left = window_left;
     // NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
     // will be problematic because of the params[i]
     ADDITIONAL_PARAMS_SETTER

csrc/batch_attention_customize_config.jinja

Lines changed: 23 additions & 7 deletions
@@ -7,7 +7,7 @@
 #include <flashinfer/fastdiv.cuh>
 #include <flashinfer/attention/variant_helper.cuh>
 #include <flashinfer/profiler.cuh>
-
+#include <cassert>
 using namespace flashinfer;

 #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}

@@ -22,29 +22,44 @@ using namespace flashinfer;

 {{ variant_decl }}

-template <bool UseLogitsSoftCap>
+template <bool use_sliding_window, bool use_logits_soft_cap>
 struct StandardAttention : AttentionVariantBase {
-  float sm_scale_log2;
-  float soft_cap_pre_tanh_scale;
-  static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
+  static constexpr bool UseLogitsSoftCap = use_logits_soft_cap;
+  static constexpr bool UseSlidingWindow = use_sliding_window;
+
+  uint32_t window_left;
+  float sm_scale_log2, soft_cap_pre_tanh_scale;
   PROFILER_CLOSURE_PARAMS_DECL

   template <typename Params>
   __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
                                         uint8_t* smem_ptr) {
-    if constexpr (UseLogitsSoftCap) {
+    if constexpr (use_logits_soft_cap) {
       soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
       sm_scale_log2 = math::log2e * params.logits_soft_cap;
     }else{
       sm_scale_log2 = params.sm_scale * math::log2e;
     }
+
   }
   REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
-    if constexpr (UseLogitsSoftCap) {
+    if constexpr (use_logits_soft_cap) {
       logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
     }
     return logits;
   })
+
+  REGISTER_LOGITS_MASK(params, work_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
+    bool mask = true;
+    // TODO: to support custom mask (only used by Spec decoding in SGL), must register request_indices in plan info
+    if constexpr (use_sliding_window) {
+      uint32_t qo_len = params.q_len[work_idx];
+      uint32_t kv_len = params.kv_len[work_idx];
+      window_left = (params.window_left >= 0) ? params.window_left : kv_len;
+      mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx);
+    }
+    return mask;
+  })
 };

 #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \

@@ -110,6 +125,7 @@ struct PersistentParams {

   float sm_scale;
   double logits_soft_cap;
+  uint32_t window_left;
   {{ additional_params_decl }}

   PROFILER_PARAMS_DECL
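For readers unfamiliar with the convention, the new REGISTER_LOGITS_MASK body keeps a key only if it lies within window_left positions to the left of the query's absolute position, with a negative window_left meaning "no window"; the causal bound itself is still enforced by the existing mask mode, so the variant only trims the left edge. A minimal host-side Python sketch of the same predicate (the function and its argument names are illustrative, not part of the commit):

```python
# Host-side sketch of the predicate used in REGISTER_LOGITS_MASK above.
# Names (qo_idx, kv_idx, qo_len, kv_len, window_left) mirror the kernel parameters.
def sliding_window_visible(qo_idx: int, kv_idx: int,
                           qo_len: int, kv_len: int,
                           window_left: int) -> bool:
    # window_left < 0 disables the window: fall back to the full kv length.
    window = window_left if window_left >= 0 else kv_len
    # Query qo_idx sits at absolute position kv_len - qo_len + qo_idx, so the mask
    # keeps keys no more than `window` positions to its left:
    #   kv_idx >= (kv_len - qo_len + qo_idx) - window
    # which rearranges to the kernel's branch-free comparison:
    return kv_idx + qo_len + window >= kv_len + qo_idx
```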

csrc/batch_attention_jit_pybind.cu

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
-                            double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
+                            int64_t page_size, double sm_scale, double logits_soft_cap,
+                            int64_t window_left ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("plan", &BatchPagedAttentionPlan);

flashinfer/attention.py

Lines changed: 51 additions & 2 deletions
@@ -74,14 +74,57 @@ def plan(
         page_size: int,
         causal: bool = False,
         sm_scale: float = None,
+        window_left: int = -1,
         logits_soft_cap: Optional[float] = None,
         q_data_type: torch.dtype = torch.bfloat16,
         kv_data_type: torch.dtype = torch.bfloat16,
         use_profiler: bool = False,
     ) -> None:
+        """Plan batch persistent attention on Paged KV-Cache with mixed prefill and decode batch.
+        Parameters
+        ----------
+        qo_indptr : torch.Tensor
+            The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
+        kv_indptr : torch.Tensor
+            The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+        kv_indices : torch.Tensor
+            The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``.
+        kv_len_arr : torch.Tensor
+            The kv length of each request, shape: ``[batch_size]``. Will be used in place of last_page_len.
+        num_qo_heads : int
+            The number of query/output heads.
+        num_kv_heads : int
+            The number of key/value heads.
+        head_dim_qk : int
+            The dimension of the query/key heads.
+        head_dim_vo : int
+            The dimension of the value/output heads.
+        page_size : int
+            The size of each page in the paged kv-cache.
+        causal : bool
+            Whether to apply causal mask to the attention matrix.
+        sm_scale : float
+            The scale used in softmax, if not provided, will be set to
+            ``1.0 / sqrt(head_dim)``.
+        window_left : int
+            The left (inclusive) window size for the attention window, when set to ``-1``, the window
+            size will be set to the full length of the sequence. Defaults to ``-1``.
+        logits_soft_cap : Optional[float]
+            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
+            provided, will be set to ``0``. If greater than 0, the logits will be capped according to
+            formula:
+            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
+            where :math:`x` is the input logits.
+        q_data_type : Union[str, torch.dtype]
+            The data type of the query tensor, defaults torch.float16.
+        kv_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        use_profiler : bool
+            Whether to use the profiler, defaults to ``False``.
+        """
+
         if logits_soft_cap is None:
             logits_soft_cap = 0.0
-        self._logits_soft_cap = logits_soft_cap

         # get jit module
         get_module_args = (

@@ -93,6 +136,7 @@ def plan(
             head_dim_vo,
             PosEncodingMode["NONE"].value,
             logits_soft_cap > 0.0,
+            window_left >= 0, # use_sliding_window
             use_profiler, # different compiler path
         )
         self.module = get_holistic_attention_module(*get_module_args)

@@ -111,6 +155,8 @@ def plan(
         self._page_size = page_size
         self._sm_scale = sm_scale
         self._use_profiler = use_profiler
+        self._logits_soft_cap = logits_soft_cap
+        self._window_left = window_left

         # No addtional buf allocated for CUDA graph tensor
         # Allocate outside FlashInfer

@@ -137,6 +183,7 @@ def run(
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
         logits_soft_cap: float = 0.0,
+        window_left: Optional[int] = None,
         profiler_buffer: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if profiler_buffer is None:

@@ -146,7 +193,7 @@ def run(
             )
         if logits_soft_cap > 0.0 and self._logits_soft_cap <= 0.0:
             raise ValueError(
-                "logits_soft_cap used in kernel run but not provided in plan(). This will cause template deduction error."
+                "logits_soft_cap used in kernel run but not provided in plan(). This will cause the wrong template used."
             )

         k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)

@@ -162,6 +209,7 @@ def run(

         # profiler_buffer is optional
         profiler_args = (profiler_buffer,) if self._use_profiler else ()
+        window_left = window_left if window_left is not None else self._window_left

         self.module.run(
             self.float_workspace_buffer,

@@ -180,6 +228,7 @@ def run(
             self._page_size,
             self._sm_scale,
             logits_soft_cap,
+            window_left,
             # ADDITIONAL_FUNC_PARAMS
             # PROFILER_FUNC_PARAMS
             *profiler_args,
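Taken together with the plan() docstring above, the new argument would typically be exercised as in the sketch below. This is a minimal usage sketch assuming the signatures shown in this diff; the single-request setup, tensor shapes and the NHD paged-cache layout are illustrative assumptions, not taken from the commit.

```python
import torch
import flashinfer

device = "cuda"
num_qo_heads, num_kv_heads, head_dim, page_size = 8, 8, 128, 16

# One request: 64 query tokens attending to 256 cached KV tokens (16 full pages).
qo_indptr = torch.tensor([0, 64], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, 16], dtype=torch.int32, device=device)
kv_indices = torch.arange(16, dtype=torch.int32, device=device)
kv_len_arr = torch.tensor([256], dtype=torch.int32, device=device)

wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_len_arr,
    num_qo_heads, num_kv_heads, head_dim, head_dim, page_size,
    causal=True,
    window_left=128,  # each query only attends to its 128 most recent keys
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)

q = torch.randn(64, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
# Assumed NHD paged layout: [num_pages, 2 (K/V), page_size, num_kv_heads, head_dim].
kv_cache = torch.randn(16, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device=device)

out, lse = wrapper.run(q, kv_cache)  # window_left defaults to the value from plan()
```

Passing window_left again at run() time is optional; as the diff shows, it falls back to the value recorded at plan() time.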

flashinfer/jit/attention/pytorch.py

Lines changed: 8 additions & 1 deletion
@@ -395,6 +395,7 @@ def get_batch_attention_uri(
     head_dim_qk: int,
     head_dim_vo: int,
     pos_encoding_mode: int,
+    use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_profiler: bool,
 ) -> str:

@@ -406,6 +407,7 @@ def get_batch_attention_uri(
         f"head_dim_qk_{head_dim_qk}_"
         f"head_dim_vo_{head_dim_vo}_"
         f"posenc_{pos_encoding_mode}_"
+        f"use_swa_{use_sliding_window}_"
         f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_"
         f"use_profiler_{str(use_profiler).lower()}"
     )

@@ -864,6 +866,7 @@ def gen_batch_attention_module(
     head_dim_vo: int,
     pos_encoding_mode: int,
     use_logits_soft_cap: bool,
+    use_sliding_window: bool,
     use_profiler: bool,
 ):
     uri = get_batch_attention_uri(

@@ -875,14 +878,15 @@ def gen_batch_attention_module(
         head_dim_vo,
         pos_encoding_mode,
         use_logits_soft_cap,
+        use_sliding_window,
         use_profiler,
     )

     additional_tensor_names = []
     additional_tensor_dtypes = []
     additional_scalar_names = []
     additional_scalar_dtypes = []
-    variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>"
+    variant_name = f"StandardAttention<{str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}>"
     variant_decl = f"#include<flashinfer/attention/variants.cuh>"

     return gen_customize_batch_attention_module(

@@ -901,6 +905,7 @@ def gen_batch_attention_module(
         variant_decl,
         pos_encoding_mode=pos_encoding_mode,
         use_logits_soft_cap=use_logits_soft_cap,
+        use_sliding_window=use_sliding_window,
         use_profiler=use_profiler,
     )

@@ -1513,6 +1518,7 @@ def gen_customize_batch_attention_module(
     variant_decl: str,
     pos_encoding_mode: int = 0,
     use_logits_soft_cap: bool = False,
+    use_sliding_window: bool = False,
     use_profiler: bool = False,
 ):
     kwargs = {

@@ -1526,6 +1532,7 @@ def gen_customize_batch_attention_module(
         "head_dim_vo": head_dim_vo,
         "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode],
         "use_logits_soft_cap": str(use_logits_soft_cap).lower(),
+        "use_sliding_window": str(use_sliding_window).lower(),
     }
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
     (additional_params_decl, additional_func_params, additional_params_setter) = (
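For a concrete feel of how the new flag shows up in the JIT cache key and the C++ variant instantiation, the sketch below evaluates the same f-string fragments as above for one flag combination; it is illustrative only and omits the other URI fields.

```python
# Evaluates the f-string fragments added above for use_sliding_window=True,
# use_logits_soft_cap=False.
use_sliding_window, use_logits_soft_cap = True, False

uri_fragment = (
    f"use_swa_{use_sliding_window}_"
    f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_"
)
variant_name = (
    f"StandardAttention<{str(use_sliding_window).lower()}, "
    f"{str(use_logits_soft_cap).lower()}>"
)
print(uri_fragment)   # use_swa_True_use_logits_soft_cap_false_
print(variant_name)   # StandardAttention<true, false>
```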

include/flashinfer/attention/persistent.cuh

Lines changed: 5 additions & 5 deletions
@@ -298,16 +298,16 @@ struct BlockBatchPagedAttentionPersistent {
         __syncthreads();

         compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
-        if constexpr (AttentionVariant::use_logits_soft_cap) {
+        if constexpr (AttentionVariant::UseLogitsSoftCap) {
           logits_transform<KTraits>(
               params, variant, /*batch_idx=*/0, qo_packed_idx_base,
               kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                   NUM_MMA_KV * 16,
               q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
         }
         if constexpr (WITH_MASK) {
-          logits_mask<KTraits>(
-              params, variant, /*batch_idx=*/0, qo_packed_idx_base,
+          logits_mask<KTraits>( // work_idx is for window_left
+              params, variant, /*work_idx=*/work_idx, qo_packed_idx_base,
               kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                   NUM_MMA_KV * 16,
               q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);

@@ -336,15 +336,15 @@ struct BlockBatchPagedAttentionPersistent {
 #pragma unroll
       for (; kv_tile_idx >= 0; --kv_tile_idx) {
         compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
-        if constexpr (AttentionVariant::use_logits_soft_cap) {
+        if constexpr (AttentionVariant::UseLogitsSoftCap) {
           logits_transform<KTraits>(
               params, variant, /*batch_idx=*/0, qo_packed_idx_base,
               kv_start +
                   (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
               q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
         }
         logits_mask<KTraits>(
-            params, variant, /*batch_idx=*/0, qo_packed_idx_base,
+            params, variant, /*work_idx=*/work_idx, qo_packed_idx_base,
             kv_start +
                 (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
             q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);

tests/test_batch_attention.py

Lines changed: 10 additions & 2 deletions
@@ -66,6 +66,7 @@ def _run_attention(
     layout="NHD",
     test_dtype=torch.bfloat16,
     logits_soft_cap=0.0,
+    window_left=-1,
     device="cuda",
     causal=True,
 ):

@@ -129,8 +130,11 @@ def _run_attention(
         q_data_type=test_dtype,
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
+        window_left=window_left,
+    )
+    out_old, lse_old = wrapper_old.run(
+        q, kv_data, return_lse=True, window_left=window_left
     )
-    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)

     # --------- new / mixed scheduler --------- #
     wrapper = flashinfer.BatchAttention(kv_layout=layout)

@@ -148,6 +152,7 @@ def _run_attention(
         q_data_type=test_dtype,
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
+        window_left=window_left,
     )
     out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)

@@ -157,14 +162,15 @@ def _run_attention(

 # ------------------------- PyTest test case ----------------------------- #
 @pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs())
-@pytest.mark.parametrize("page_block_size", [1, 8, 16])
+@pytest.mark.parametrize("page_block_size", [8, 16])
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
 @pytest.mark.parametrize("gqa_group_size", [1, 4, 7])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0])
+@pytest.mark.parametrize("window_left", [13, -1])
 def test_batch_attention_correctness(
     seq_len_pairs,
     page_block_size,

@@ -175,6 +181,7 @@ def test_batch_attention_correctness(
     layout,
     test_dtype,
     logits_soft_cap,
+    window_left,
 ):
     num_qo_heads = num_kv_heads * gqa_group_size
     kv_lens = [p[0] for p in seq_len_pairs]

@@ -191,5 +198,6 @@ def test_batch_attention_correctness(
         layout=layout,
         test_dtype=test_dtype,
         logits_soft_cap=logits_soft_cap,
+        window_left=window_left,
         device="cuda",
     )
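For context on what the parametrized test is comparing against, a naive dense reference with the same causal + sliding-window visibility rule could look like the sketch below. It is purely illustrative (it assumes num_qo_heads == num_kv_heads and unpadded per-request q/k/v tensors) and is not part of the commit.

```python
import math
import torch

def ref_sliding_window_attention(q, k, v, window_left, causal=True):
    # q: [qo_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim]
    qo_len, num_heads, head_dim = q.shape
    kv_len = k.shape[0]
    scale = 1.0 / math.sqrt(head_dim)
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale

    qo_idx = torch.arange(qo_len, device=q.device)[:, None]
    kv_idx = torch.arange(kv_len, device=q.device)[None, :]
    window = window_left if window_left >= 0 else kv_len
    keep = kv_idx + qo_len + window >= kv_len + qo_idx   # sliding-window bound
    if causal:
        keep &= kv_idx + qo_len <= kv_len + qo_idx       # causal bound
    logits = logits.masked_fill(~keep[None], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)
```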
