
Commit e5d342f

feat: Add k_scale and v_scale to persistent attention (#1322)
## 📌 Description

For consistency with the general call in SGLang. cc @yzh119

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 7442e5a commit e5d342f

File tree

6 files changed: +33 −13 lines changed
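
Taken together, the changes below fold `k_scale` into the softmax scale on the Python side and apply `v_scale` to the normalized output inside the persistent kernel. As a rough reference for that semantics, here is a minimal PyTorch sketch (illustrative only; it uses a dense, single-batch layout rather than the paged KV cache the kernels operate on, and the shapes are assumptions):

```python
import math
import torch

def scaled_attention_reference(q, k, v, k_scale=None, v_scale=None, sm_scale=None):
    """Dense reference for the k_scale/v_scale semantics added in this commit.

    q: [q_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim].
    """
    head_dim = q.shape[-1]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    if k_scale is not None:
        sm_scale *= k_scale  # mirrors `sm_scale *= k_scale` in flashinfer/attention.py
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    p = torch.softmax(logits, dim=-1)
    out = torch.einsum("hqk,khd->qhd", p, v.float())
    if v_scale is not None:
        out = out * v_scale  # mirrors the v_scale multiply in normalize_d
    return out
```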

csrc/batch_attention.cu

Lines changed: 4 additions & 1 deletion

@@ -68,7 +68,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size,
+                            double v_scale,  // must use double due to pytorch binding
+                            double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));

@@ -171,6 +173,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
     params[i].v_stride_n = v_stride_n;

     params[i].sm_scale = sm_scale;
+    params[i].v_scale = v_scale;
     params[i].logits_soft_cap = logits_soft_cap;
     // NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
     // will be problematic because of the params[i]
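
One detail worth noting: `v_scale` crosses the PyTorch binding as a `double` (per the comment above) but is stored as a `float` in `PersistentParams` (see the jinja change below), so it is narrowed to single precision before it reaches the kernel. A tiny illustration of that narrowing, using torch only as a convenient float32 container:

```python
import torch

v_scale = 0.1  # a Python float, i.e. a C++ double at the binding boundary
v_scale_f32 = torch.tensor(v_scale, dtype=torch.float32)  # what the kernel ultimately sees
print(f"{v_scale:.17g} -> {v_scale_f32.item():.17g}")  # e.g. 0.1 -> 0.10000000149011612
```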

csrc/batch_attention_customize_config.jinja

Lines changed: 2 additions & 1 deletion

@@ -109,7 +109,8 @@ struct PersistentParams {
   uint32_t v_stride_n;

   float sm_scale;
-  double logits_soft_cap;
+  float logits_soft_cap;
+  float v_scale;
   {{ additional_params_decl }}

   PROFILER_PARAMS_DECL

csrc/batch_attention_jit_pybind.cu

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {

flashinfer/attention.py

Lines changed: 11 additions & 5 deletions

@@ -109,7 +109,6 @@ def plan(
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
         self._page_size = page_size
-        self._sm_scale = sm_scale
         self._use_profiler = use_profiler

         # No addtional buf allocated for CUDA graph tensor

@@ -135,6 +134,8 @@ def run(
         kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
         logits_soft_cap: float = 0.0,
         profiler_buffer: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -157,9 +158,13 @@ def run(
                 q.shape[0], q.shape[1], device=q.device, dtype=torch.float32
             )
         head_dim_qk = q.shape[2]
-        if self._sm_scale is None:
-            self._sm_scale = 1.0 / math.sqrt(head_dim_qk)
-
+        sm_scale = self._sm_scale
+        if sm_scale is None:
+            sm_scale = 1.0 / math.sqrt(head_dim_qk)
+        if k_scale is not None:
+            sm_scale *= k_scale
+        if v_scale is None:
+            v_scale = 1.0
         # profiler_buffer is optional
         profiler_args = (profiler_buffer,) if self._use_profiler else ()

@@ -178,7 +183,8 @@ def run(
             self._num_qo_heads,
             self._num_kv_heads,
             self._page_size,
-            self._sm_scale,
+            v_scale,
+            sm_scale,
             logits_soft_cap,
             # ADDITIONAL_FUNC_PARAMS
             # PROFILER_FUNC_PARAMS
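
From the caller's perspective, the new scales are per-`run()` arguments rather than `plan()`-time state. A hedged usage sketch (the wrapper construction and `plan(...)` call are elided and follow the existing `BatchAttention` API; callers such as SGLang would typically pass KV-cache dequantization scales here):

```python
# `wrapper` is an already-planned flashinfer.BatchAttention instance;
# q / kv_cache follow the layouts expected by the existing API.
out, lse = wrapper.run(
    q,
    kv_cache,
    k_scale=k_scale,  # folded into sm_scale before the kernel launch
    v_scale=v_scale,  # applied to the output inside the persistent kernel
    logits_soft_cap=0.0,
)
```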

include/flashinfer/attention/persistent.cuh

Lines changed: 6 additions & 2 deletions

@@ -140,7 +140,8 @@ __device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][

 template <typename KTraits>
 __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
-                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
+                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2],
+                                            float v_scale = 1.0f) {
   using AttentionVariant = typename KTraits::AttentionVariant;
   if constexpr (AttentionVariant::use_softmax) {
     float d_rcp[KTraits::NUM_MMA_Q][2];

@@ -163,6 +164,9 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V
       for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
         o_frag[mma_q][mma_d][reg_id] =
             o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
+        if (v_scale != 1.0f) {
+          o_frag[mma_q][mma_d][reg_id] *= v_scale;
+        }
       }
     }
   }

@@ -391,7 +395,7 @@ struct BlockBatchPagedAttentionPersistent {
     threadblock_sync_mdo_states<KTraits>(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid);

     // normalize d
-    normalize_d<KTraits>(o_frag, m, d);
+    normalize_d<KTraits>(o_frag, m, d, params.v_scale);

     // write back to global memory
     // o_indptr (partial_o): [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim]
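
Since the multiply happens after the `d_rcp` normalization, scaling each output fragment by `v_scale` is mathematically the same as scaling the final attention output (or the V values) by that factor. A quick sanity check of that identity using PyTorch's reference SDPA (illustrative only, not the persistent kernel):

```python
import torch

q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 32, 64)
v = torch.randn(1, 4, 32, 64)
v_scale = 2.0

out_scaled_v = torch.nn.functional.scaled_dot_product_attention(q, k, v * v_scale)
out_scaled_o = torch.nn.functional.scaled_dot_product_attention(q, k, v) * v_scale
torch.testing.assert_close(out_scaled_v, out_scaled_o)  # attention is linear in V
```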

tests/test_batch_attention.py

Lines changed: 9 additions & 3 deletions

@@ -90,6 +90,7 @@ def _run_attention(
     num_kv_heads=1,
     num_qo_heads=1,
     head_dim=128,
+    v_scale=None,
     layout="NHD",
     test_dtype=torch.bfloat16,
     logits_soft_cap=0.0,

@@ -140,7 +141,7 @@ def _run_attention(

     # --------- old scheduler --------- #
     wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev),
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),
         kv_layout=layout,
         backend="fa2",
     )

@@ -159,7 +160,7 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)
+    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True, v_scale=v_scale)

     # --------- new / mixed scheduler --------- #
     wrapper = flashinfer.BatchAttention(kv_layout=layout)

@@ -178,7 +179,9 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)
+    out_new, lse_new = wrapper.run(
+        q, kv_data, v_scale=v_scale, logits_soft_cap=logits_soft_cap
+    )

     torch.cuda.synchronize()
     torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)

@@ -191,6 +194,7 @@ def _run_attention(
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
 @pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("v_scale", [2.0, None])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])

@@ -201,6 +205,7 @@ def test_batch_attention_correctness(
     num_kv_heads,
     gqa_group_size,
     head_dim,
+    v_scale,
     causal,
     layout,
     test_dtype,

@@ -217,6 +222,7 @@ def test_batch_attention_correctness(
         num_kv_heads=num_kv_heads,
         num_qo_heads=num_qo_heads,
         head_dim=head_dim,
+        v_scale=v_scale,
         causal=causal,
         layout=layout,
         test_dtype=test_dtype,
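
With the new `v_scale` parametrization in place, the whole module can be run locally to compare the old and new schedulers; a minimal sketch (assumes a CUDA-capable GPU and a FlashInfer build that includes these changes):

```python
import sys
import pytest

# Runs tests/test_batch_attention.py, including the new v_scale in {2.0, None} cases.
sys.exit(pytest.main(["-q", "tests/test_batch_attention.py"]))
```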
