Commit bf3445f

feat: Support logits_soft_cap for Persistent attn; fix kv split limit (#1324)
## 📌 Description

When integrating this kernel into SGLang, I quickly hit an assertion error with input len 4000, output len 200, and 8 requests/s due to the hard limit of 4 kv splits per tile size per SM. This PR fixes that constraint.

<img width="652" height="52" alt="image" src="https://github.com/user-attachments/assets/de570432-5de0-4a82-9612-8dec51b9338a" />

It also adds support for logits_soft_cap, which is used by the Gemma model in SGLang.

cc @happierpig @yzh119

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 5ec5371 commit bf3445f

10 files changed, +86 -24 lines changed

benchmarks/bench_batch_attention.py

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,7 @@ def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]:
         [(8192, 1)] * 128,  # decode-only
         [(4096, 128)] * 4,  # prefill-only
         [(600, 1)] * 122 + [(10_000, 17)] * 8,  # hybird
-        [(8192, 1)] * 127 * 2 + [(2048, 512)] * 1,  # hybrid (chunked-prefill)
+        [(8192, 1)] * 127 * 2 + [(8192, 4096)] * 1,  # hybrid (chunked-prefill)
     ]
 
 
 def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]:
@@ -198,6 +198,7 @@ def main() -> None:
         ],
     )
     print(df.to_markdown(index=False, floatfmt=".2f"))
+    df.to_csv("bench_batch_attention.csv", index=False)
 
 
 if __name__ == "__main__":

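Since the benchmark now also persists its results table via `df.to_csv`, here is a minimal sketch of picking the file back up later, e.g. for plotting or cross-commit comparison (assumes only that pandas is installed, which the benchmark already requires):

```python
import pandas as pd

# Load the table written by bench_batch_attention.py.
df = pd.read_csv("bench_batch_attention.csv")
print(df.head())
```
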
csrc/batch_attention.cu

Lines changed: 5 additions & 3 deletions
@@ -68,8 +68,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size,
-                            double sm_scale ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
+                            int64_t page_size, double sm_scale,
+                            double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
 
@@ -171,7 +171,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
     params[i].v_stride_n = v_stride_n;
 
     params[i].sm_scale = sm_scale;
-
+    params[i].logits_soft_cap = logits_soft_cap;
+    // NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
+    // will be problematic because of the params[i]
     ADDITIONAL_PARAMS_SETTER
     PROFILER_PARAMS_SETTER
   }

csrc/batch_attention_customize_config.jinja

Lines changed: 16 additions & 3 deletions
@@ -22,16 +22,29 @@ using namespace flashinfer;
 
 {{ variant_decl }}
 
+template <bool UseLogitsSoftCap>
 struct StandardAttention : AttentionVariantBase {
   float sm_scale_log2;
-
+  float soft_cap_pre_tanh_scale;
+  static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
   PROFILER_CLOSURE_PARAMS_DECL
 
   template <typename Params>
   __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
                                         uint8_t* smem_ptr) {
-    sm_scale_log2 = params.sm_scale * math::log2e;
+    if constexpr (UseLogitsSoftCap) {
+      soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
+      sm_scale_log2 = math::log2e * params.logits_soft_cap;
+    }else{
+      sm_scale_log2 = params.sm_scale * math::log2e;
+    }
   }
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
+    if constexpr (UseLogitsSoftCap) {
+      logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
+    }
+    return logits;
+  })
 };
 
 #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \
@@ -96,7 +109,7 @@ struct PersistentParams {
   uint32_t v_stride_n;
 
   float sm_scale;
-
+  double logits_soft_cap;
   {{ additional_params_decl }}
 
   PROFILER_PARAMS_DECL

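For context on the factorization above: the variant applies only the `tanh` in the logits transform (with `soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap`, computed via an approximate reciprocal `math::ptx_rcp` in the kernel), and folds the cap into `sm_scale_log2 = log2e * logits_soft_cap`, which is the scale used by the exp2-based online softmax. The net effect is the usual Gemma-style capped score `logits_soft_cap * tanh(q·k * sm_scale / logits_soft_cap)`. A minimal NumPy sketch of that equivalence (illustrative shapes and scales, exact division instead of `ptx_rcp`):

```python
import numpy as np

def softmax_numerator_reference(raw_qk, sm_scale, logits_soft_cap):
    # exp() of the soft-capped, scaled logits.
    return np.exp(logits_soft_cap * np.tanh(raw_qk * sm_scale / logits_soft_cap))

def softmax_numerator_kernel_style(raw_qk, sm_scale, logits_soft_cap):
    # Mirrors StandardAttention<true>: the logits transform applies only tanh
    # (with soft_cap_pre_tanh_scale), and the cap is re-applied through the
    # exp2-based softmax scale sm_scale_log2 = log2e * logits_soft_cap.
    soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap
    sm_scale_log2 = np.log2(np.e) * logits_soft_cap
    transformed = np.tanh(raw_qk * soft_cap_pre_tanh_scale)
    return np.exp2(transformed * sm_scale_log2)

rng = np.random.default_rng(0)
raw_qk = rng.standard_normal((16, 64))
a = softmax_numerator_reference(raw_qk, sm_scale=128 ** -0.5, logits_soft_cap=50.0)
b = softmax_numerator_kernel_style(raw_qk, sm_scale=128 ** -0.5, logits_soft_cap=50.0)
assert np.allclose(a, b)
```
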
csrc/batch_attention_jit_pybind.cu

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size,
-                            double sm_scale ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
+                            int64_t page_size, double sm_scale,
+                            double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
 
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("plan", &BatchPagedAttentionPlan);

flashinfer/attention.py

Lines changed: 15 additions & 0 deletions
@@ -74,10 +74,15 @@ def plan(
         page_size: int,
         causal: bool = False,
         sm_scale: float = None,
+        logits_soft_cap: Optional[float] = None,
         q_data_type: torch.dtype = torch.bfloat16,
         kv_data_type: torch.dtype = torch.bfloat16,
         use_profiler: bool = False,
     ) -> None:
+        if logits_soft_cap is None:
+            logits_soft_cap = 0.0
+        self._logits_soft_cap = logits_soft_cap
+
         # get jit module
         get_module_args = (
             q_data_type,
@@ -87,6 +92,7 @@ def plan(
             head_dim_qk,
             head_dim_vo,
             PosEncodingMode["NONE"].value,
+            logits_soft_cap > 0.0,
             use_profiler,  # different compiler path
         )
         self.module = get_holistic_attention_module(*get_module_args)
@@ -130,13 +136,19 @@ def run(
         kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
+        logits_soft_cap: float = 0.0,
         profiler_buffer: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if profiler_buffer is None:
             if self._use_profiler:
                 raise ValueError(
                     "Profiler is enabled, profiler_buffer must be provided"
                 )
+        if logits_soft_cap > 0.0 and self._logits_soft_cap <= 0.0:
+            raise ValueError(
+                "logits_soft_cap used in kernel run but not provided in plan(). This will cause template deduction error."
+            )
+
         k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)
         if out is None:
             out = torch.empty_like(q)
@@ -167,6 +179,9 @@ def run(
             self._num_kv_heads,
             self._page_size,
             self._sm_scale,
+            logits_soft_cap,
+            # ADDITIONAL_FUNC_PARAMS
+            # PROFILER_FUNC_PARAMS
             *profiler_args,
         )
 

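The intended calling pattern is to pass `logits_soft_cap` to `plan()`, which both selects the `StandardAttention<UseLogitsSoftCap>` instantiation at JIT time and caches the value, and then optionally pass it again to `run()`. Below is a standalone sketch of the guard logic added above, using a toy class rather than the real wrapper (names here are hypothetical; only the checks mirror the diff):

```python
from typing import Optional

class _SoftCapGuardSketch:
    """Toy stand-in for the plan()/run() contract added in this PR."""

    def plan(self, logits_soft_cap: Optional[float] = None) -> None:
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        self._logits_soft_cap = logits_soft_cap
        # logits_soft_cap > 0.0 is what picks StandardAttention<true> at JIT time.
        self._use_logits_soft_cap = logits_soft_cap > 0.0

    def run(self, logits_soft_cap: float = 0.0) -> None:
        if logits_soft_cap > 0.0 and self._logits_soft_cap <= 0.0:
            raise ValueError(
                "logits_soft_cap used in kernel run but not provided in plan()."
            )

g = _SoftCapGuardSketch()
g.plan(logits_soft_cap=50.0)
g.run(logits_soft_cap=50.0)   # fine: the module was planned with the cap enabled
g.plan()                      # cap disabled at plan time
try:
    g.run(logits_soft_cap=50.0)
except ValueError as exc:
    print(exc)                # the error raised by the new check
```
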
flashinfer/jit/attention/pytorch.py

Lines changed: 8 additions & 2 deletions
@@ -395,6 +395,7 @@ def get_batch_attention_uri(
     head_dim_qk: int,
     head_dim_vo: int,
     pos_encoding_mode: int,
+    use_logits_soft_cap: bool,
     use_profiler: bool,
 ) -> str:
     return (
@@ -405,6 +406,7 @@ def get_batch_attention_uri(
         f"head_dim_qk_{head_dim_qk}_"
         f"head_dim_vo_{head_dim_vo}_"
         f"posenc_{pos_encoding_mode}_"
+        f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_"
         f"use_profiler_{str(use_profiler).lower()}"
     )
 
@@ -861,6 +863,7 @@ def gen_batch_attention_module(
     head_dim_qk: int,
     head_dim_vo: int,
     pos_encoding_mode: int,
+    use_logits_soft_cap: bool,
     use_profiler: bool,
 ):
     uri = get_batch_attention_uri(
@@ -871,14 +874,15 @@ def gen_batch_attention_module(
         head_dim_qk,
         head_dim_vo,
         pos_encoding_mode,
+        use_logits_soft_cap,
         use_profiler,
     )
 
     additional_tensor_names = []
     additional_tensor_dtypes = []
     additional_scalar_names = []
     additional_scalar_dtypes = []
-    variant_name = f"StandardAttention"
+    variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>"
     variant_decl = f"#include<flashinfer/attention/variants.cuh>"
 
     return gen_customize_batch_attention_module(
@@ -896,6 +900,7 @@ def gen_batch_attention_module(
         variant_name,
         variant_decl,
         pos_encoding_mode=pos_encoding_mode,
+        use_logits_soft_cap=use_logits_soft_cap,
         use_profiler=use_profiler,
     )
 
@@ -1507,6 +1512,7 @@ def gen_customize_batch_attention_module(
     variant_name: str,
     variant_decl: str,
     pos_encoding_mode: int = 0,
+    use_logits_soft_cap: bool = False,
     use_profiler: bool = False,
 ):
     kwargs = {
@@ -1519,6 +1525,7 @@ def gen_customize_batch_attention_module(
         "head_dim_qk": head_dim_qk,
         "head_dim_vo": head_dim_vo,
         "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode],
+        "use_logits_soft_cap": str(use_logits_soft_cap).lower(),
     }
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
     (additional_params_decl, additional_func_params, additional_params_setter) = (
@@ -1529,7 +1536,6 @@ def gen_customize_batch_attention_module(
             additional_scalar_dtypes,
         )
     )
-
     with open(
         jit_env.FLASHINFER_CSRC_DIR / "batch_attention_customize_config.jinja"
     ) as f:

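A quick illustration of how the flag propagates into the generated artifacts: the Python boolean is lowercased into both the JIT cache URI and the C++ template argument, so capped and uncapped modules are compiled and cached separately (the snippet below just mirrors the f-strings in the diff):

```python
# Illustrative only: the same lowercased flag feeds the variant name and the uri.
for use_logits_soft_cap in (False, True):
    flag = str(use_logits_soft_cap).lower()
    print(f"StandardAttention<{flag}>")      # variant_name passed to codegen
    print(f"use_logits_soft_cap_{flag}_")    # fragment appended to the module uri
# StandardAttention<false> / use_logits_soft_cap_false_
# StandardAttention<true>  / use_logits_soft_cap_true_
```
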
include/flashinfer/attention/persistent.cuh

Lines changed: 14 additions & 0 deletions
@@ -298,6 +298,13 @@ struct BlockBatchPagedAttentionPersistent {
       __syncthreads();
 
       compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+      if constexpr (AttentionVariant::use_logits_soft_cap) {
+        logits_transform<KTraits>(
+            params, variant, /*batch_idx=*/0, qo_packed_idx_base,
+            kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
+                           NUM_MMA_KV * 16,
+            q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
+      }
       if constexpr (WITH_MASK) {
         logits_mask<KTraits>(
             params, variant, /*batch_idx=*/0, qo_packed_idx_base,
@@ -329,6 +336,13 @@ struct BlockBatchPagedAttentionPersistent {
 #pragma unroll
     for (; kv_tile_idx >= 0; --kv_tile_idx) {
       compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
+      if constexpr (AttentionVariant::use_logits_soft_cap) {
+        logits_transform<KTraits>(
+            params, variant, /*batch_idx=*/0, qo_packed_idx_base,
+            kv_start +
+                (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
+            q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
+      }
       logits_mask<KTraits>(
           params, variant, /*batch_idx=*/0, qo_packed_idx_base,
           kv_start +

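Not the CUDA code, but a small runnable stand-in for the per-KV-tile order the persistent kernel now follows: QK product, then the tanh transform (compiled in only when the variant enables it, via `if constexpr`), then masking, before the online-softmax update. The shapes, the toy mask, and the eager scaling in the else-branch are simplifications for illustration:

```python
import numpy as np

def process_kv_tile(q, k, sm_scale, logits_soft_cap, mask):
    s = q @ k.T                                   # compute_qk
    if logits_soft_cap > 0.0:                     # constexpr branch in the kernel
        s = logits_soft_cap * np.tanh(s * sm_scale / logits_soft_cap)  # logits_transform
    else:
        s = s * sm_scale
    return np.where(mask, s, -np.inf)             # logits_mask; the m/d/o update follows

q = np.random.default_rng(0).standard_normal((16, 128))
k = np.random.default_rng(1).standard_normal((64, 128))
mask = np.tril(np.ones((16, 64), dtype=bool), k=48)   # toy causal-style mask for one tile
print(process_kv_tile(q, k, 128 ** -0.5, 50.0, mask).shape)  # (16, 64)
```
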
include/flashinfer/attention/prefill.cuh

Lines changed: 0 additions & 1 deletion
@@ -2287,7 +2287,6 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
 
       // compute attention score
       compute_qk<KTraits>(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
-
       logits_transform<KTraits>(
           params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
           chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,

include/flashinfer/attention/scheduler.cuh

Lines changed: 17 additions & 11 deletions
@@ -1145,8 +1145,8 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
 
   // NOTE(Zihao): adjust it later
-  const int max_total_num_works = 16384;
-  const int max_packed_qo_lens =
+  const int max_total_num_works = 65536;
+  const int max_num_kv_splits =
       4 * num_clusters * cluster_size * (CTA_TILE_Q_SIZES[0] + CTA_TILE_Q_SIZES[1]);
 
   // calculate kv_len_limit first, considering all workloads
@@ -1167,7 +1167,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   }
 
   // used for remapping the output offsets
-  // layout [packed_qo_len x num_kv_tiels, num_kv_heads, head_dim]
+  // layout [packed_qo_len x num_kv_tiles, num_kv_heads, head_dim]
   int partial_o_nnz = 0;
   std::vector<IdType> merge_indptr, merge_o_indices, num_expand_qo_len_vec;
   merge_indptr.push_back(partial_o_nnz);
@@ -1251,6 +1251,12 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
     work_indptr_vec[i + 1] = work_indptr_vec[i] + cluster_q_indptr[i].size();
   }
   int total_num_works = work_indptr_vec.back();
+  if (total_num_works > max_total_num_works) {
+    std::ostringstream err_msg;
+    err_msg << "total_num_works (#q tiles * #kv tiles) " << total_num_works
+            << " exceeds max_total_num_works " << max_total_num_works;
+    FLASHINFER_ERROR(err_msg.str());
+  }
   auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works);
   auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works);
   auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works);
@@ -1306,20 +1312,20 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
                          len_kv_chunk_vec);
   }
 
-  if (partial_o_nnz > max_packed_qo_lens) {
+  if (merge_indptr.size() > max_num_kv_splits) {
     std::ostringstream err_msg;
-    err_msg << "partial_o_nnz " << partial_o_nnz << " exceeds max_packed_qo_lens "
-            << max_packed_qo_lens;
+    err_msg << "Number of kv splits " << merge_indptr.size() << " exceeds max buffer size "
+            << max_num_kv_splits << ". Please increase the threshold.";
     FLASHINFER_ERROR(err_msg.str());
   }
 
   // update num_qo_len_vec
   num_expand_qo_len_vec.push_back(merge_indptr.size() - 1);
   // allocate buffer for state merge function
   plan_info.merge_indptr_offset =
-      int_allocator.aligned_alloc_offset(sizeof(IdType) * max_packed_qo_lens, 16, "merge_indptr");
-  plan_info.merge_o_indices_offset = int_allocator.aligned_alloc_offset(
-      sizeof(IdType) * max_packed_qo_lens, 16, "merge_o_indices");
+      int_allocator.aligned_alloc_offset(sizeof(IdType) * max_num_kv_splits, 16, "merge_indptr");
+  plan_info.merge_o_indices_offset =
+      int_allocator.aligned_alloc_offset(sizeof(IdType) * max_num_kv_splits, 16, "merge_o_indices");
   plan_info.num_qo_len_offset =
       int_allocator.aligned_alloc_offset(sizeof(IdType), 16, "num_qo_len_offset");
   // copy data to paged cpu buffer
@@ -1336,9 +1342,9 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   // Note(Yilong): adjust it later
   AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
   plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
-      2 * max_packed_qo_lens * sizeof_dtype_o * head_dim, 16, "holistic_partial_o");
+      2 * max_num_kv_splits * sizeof_dtype_o * head_dim, 16, "holistic_partial_o");
   plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
-      2 * max_packed_qo_lens * sizeof(float), 16, "holistic_partial_lse");
+      2 * max_num_kv_splits * sizeof(float), 16, "holistic_partial_lse");
 
   return cudaSuccess;
 }

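A back-of-envelope look at the bound the new check compares against. The hardware-dependent terms below are hypothetical placeholders (the real values are derived at plan time from the device and the scheduler's CTA tile sizes), but they show the scale of the `merge_indptr` / partial-output budget that the number of kv splits is now checked against:

```python
# Hypothetical values for illustration only.
num_clusters = 132            # assumption: e.g. one cluster per SM on an H100-class GPU
cluster_size = 2              # assumption
CTA_TILE_Q_SIZES = (16, 128)  # assumption

max_num_kv_splits = 4 * num_clusters * cluster_size * sum(CTA_TILE_Q_SIZES)
print(max_num_kv_splits)      # 152064 merge_indptr entries with these placeholders

# TwoStageHolisticPlan now checks merge_indptr.size() against this bound
# (instead of comparing partial_o_nnz with a packed-qo-length budget), and
# max_total_num_works was raised from 16384 to 65536 with an explicit error
# when it is exceeded.
```
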
tests/test_batch_attention.py

Lines changed: 7 additions & 1 deletion
@@ -65,6 +65,7 @@ def _run_attention(
     head_dim=128,
     layout="NHD",
     test_dtype=torch.bfloat16,
+    logits_soft_cap=0.0,
     device="cuda",
     causal=True,
 ):
@@ -127,6 +128,7 @@ def _run_attention(
         causal=causal,
         q_data_type=test_dtype,
         kv_data_type=test_dtype,
+        logits_soft_cap=logits_soft_cap,
     )
     out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)
 
@@ -145,8 +147,9 @@ def _run_attention(
         causal=causal,
         q_data_type=test_dtype,
         kv_data_type=test_dtype,
+        logits_soft_cap=logits_soft_cap,
     )
-    out_new, lse_new = wrapper.run(q, kv_data)
+    out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)
 
     torch.cuda.synchronize()
     torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)
@@ -161,6 +164,7 @@ def _run_attention(
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0])
 def test_batch_attention_correctness(
     seq_len_pairs,
     page_block_size,
@@ -170,6 +174,7 @@ def test_batch_attention_correctness(
     causal,
     layout,
     test_dtype,
+    logits_soft_cap,
 ):
     num_qo_heads = num_kv_heads * gqa_group_size
     kv_lens = [p[0] for p in seq_len_pairs]
@@ -185,5 +190,6 @@ def test_batch_attention_correctness(
         causal=causal,
         layout=layout,
         test_dtype=test_dtype,
+        logits_soft_cap=logits_soft_cap,
         device="cuda",
     )
