
Commit 50dcd9a

feat: Support scale factor start index for fp4 mha prefill/decode (#1363)
## 📌 Description

The start index of the fp4 output scale factor, `o_sf_start_index`, is useful when the decode kernels reuse the scale factor tensor of the prefill kernels: decode can write from an offset even though the scale factor is swizzled. This is a follow-up of #1360; please only review the latest commit.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
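To make the intended reuse concrete, here is a minimal sketch (not part of this commit) of two `FP4Tensor` views sharing one swizzled scale-factor buffer. The token counts and head sizes are made up, and it assumes the wrapper's shape checks accept a scale buffer sized for the combined token count.

```python
# Hedged sketch: prefill and decode outputs sharing one swizzled scale-factor buffer.
# The padding rule (rows to a multiple of 128, columns to a multiple of 4) follows
# the tests added in this commit; everything else here is illustrative.
import math

import torch

from flashinfer.utils import FP4Tensor

num_prefill_tokens, num_decode_tokens = 512, 16
num_qo_heads, head_dim, o_sf_vec_size = 8, 128, 16
total_tokens = num_prefill_tokens + num_decode_tokens

sf_rows = math.ceil(total_tokens / 128) * 128
sf_cols = math.ceil(num_qo_heads * head_dim / o_sf_vec_size / 4) * 4
shared_scale = torch.empty(sf_rows, sf_cols, dtype=torch.float8_e4m3fn, device="cuda")

# Prefill writes its scale factors starting at row 0 of the shared buffer ...
prefill_out = FP4Tensor(
    torch.empty(num_prefill_tokens, num_qo_heads, head_dim // 2, dtype=torch.uint8, device="cuda"),
    shared_scale,
    scale_start_index=0,
)
# ... and decode continues right after the prefill tokens, even though the
# scale tensor is stored in the swizzled layout.
decode_out = FP4Tensor(
    torch.empty(num_decode_tokens, num_qo_heads, head_dim // 2, dtype=torch.uint8, device="cuda"),
    shared_scale,
    scale_start_index=num_prefill_tokens,
)

# The two views would then be passed as `out=` to
# flashinfer.prefill.trtllm_batch_context_with_kv_cache(...) and
# flashinfer.decode.trtllm_batch_decode_with_kv_cache(...), respectively.
```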
1 parent db16e41 commit 50dcd9a

File tree

7 files changed, +119 −17 lines


csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 11 additions & 6 deletions
```diff
@@ -80,8 +80,8 @@ void trtllm_paged_attention_launcher(
     int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
     int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
     int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
-    double o_sf_scale, int64_t o_sf_vec_size, int64_t window_left, int64_t sum_seq_q,
-    int64_t sm_count, cudaStream_t stream) {
+    double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left,
+    int64_t sum_seq_q, int64_t sm_count, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -118,6 +118,7 @@ void trtllm_paged_attention_launcher(
   runner_params.outputScale = bmm2_scale;
   runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
   runner_params.oSfPtr = out_scale_factor;
+  runner_params.mSfStartTokenIdx = o_sf_start_index;
   runner_params.mScaleSfO = o_sf_scale;
   TORCH_CHECK(o_sf_vec_size == 16 || o_sf_vec_size == -1,
               "Only support o_sf_vec_size == 16 or -1(not used)");
@@ -189,7 +190,8 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
                                    at::Tensor workspace_buffer, at::Tensor block_tables,
                                    at::Tensor seq_lens, int64_t max_kv_len, double bmm1_scale,
                                    double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
-                                   int64_t window_left, int64_t sm_count) {
+                                   int64_t o_sf_start_index, int64_t window_left,
+                                   int64_t sm_count) {
   auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
   auto kv_data_type = torch_dtype_to_tllm_data_type(key_value_cache.scalar_type());
   auto o_data_type = torch_dtype_to_tllm_data_type(out.scalar_type());
@@ -242,15 +244,17 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
       TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
-      bmm2_scale, o_sf_scale, o_sf_vec_size, window_left, sum_seq_q, sm_count, stream);
+      bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
+      stream);
 }
 
 void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> out_scale_factor,
                                     at::Tensor query, at::Tensor key_value_cache,
                                     at::Tensor workspace_buffer, at::Tensor block_tables,
                                     at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len,
                                     double bmm1_scale, double bmm2_scale, double o_sf_scale,
-                                    int64_t o_sf_vec_size, int64_t batch_size, int64_t window_left,
+                                    int64_t o_sf_vec_size, int64_t o_sf_start_index,
+                                    int64_t batch_size, int64_t window_left,
                                     at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv,
                                     int64_t sm_count) {
   auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
@@ -299,7 +303,8 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> ou
       o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
-      bmm2_scale, o_sf_scale, o_sf_vec_size, window_left, sum_seq_q, sm_count, stream);
+      bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
+      stream);
 }
 
 namespace trtllm_cubin_loader {
```

flashinfer/decode.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -1821,6 +1821,7 @@ def _paged_run(
             bmm2_scale,
             -1,  # o_sf_scale
             -1,  # o_sf_vec_size
+            0,  # o_sf_start_index
             window_left,
             self._sm_count,
         )
@@ -2021,12 +2022,14 @@ def trtllm_batch_decode_with_kv_cache(
 
     if isinstance(out, FP4Tensor):
         out_scale_factor = out.scale
+        o_sf_start_index = out.scale_start_index
        out = out.data
     elif out is None:
-        out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
         out_scale_factor = torch.empty(
             fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
         )
+        o_sf_start_index = 0
+        out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
     else:
         raise ValueError(f"Invalid out: {out}")
 
@@ -2044,6 +2047,7 @@ def trtllm_batch_decode_with_kv_cache(
         assert o_sf_scale is None
         assert o_sf_vec_size is None
         out_scale_factor = None
+        o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
         _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
@@ -2063,12 +2067,15 @@ def trtllm_batch_decode_with_kv_cache(
         bmm2_scale,
         o_sf_scale or -1.0,
         o_sf_vec_size or -1,
+        o_sf_start_index,
         window_left,
         sm_count,
     )
 
     return (
-        out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, query.shape)
+        out
+        if out_dtype != "nvfp4"
+        else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
     )
 
 
@@ -2217,6 +2224,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         bmm2_scale,
         -1,  # o_sf_scale
         -1,  # o_sf_vec_size
+        0,  # o_sf_start_index
         -1,  # window_left
         sm_count,
     )
```

flashinfer/prefill.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -128,6 +128,7 @@ def _paged_run(
             bmm2_scale,
             -1,  # o_sf_scale
             -1,  # o_sf_vec_size
+            0,  # o_sf_start_index
             batch_size,
             window_left,
             cum_seq_lens_q,
@@ -3017,12 +3018,14 @@ def trtllm_batch_context_with_kv_cache(
 
     if isinstance(out, FP4Tensor):
         out_scale_factor = out.scale
+        o_sf_start_index = out.scale_start_index
         out = out.data
     elif out is None:
-        out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
         out_scale_factor = torch.empty(
             fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
         )
+        o_sf_start_index = 0
+        out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
     else:
         raise ValueError(f"Invalid out: {out}")
 
@@ -3040,6 +3043,7 @@ def trtllm_batch_context_with_kv_cache(
         assert o_sf_scale is None
         assert o_sf_vec_size is None
         out_scale_factor = None
+        o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
         _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
@@ -3060,12 +3064,15 @@ def trtllm_batch_context_with_kv_cache(
         bmm2_scale,
         o_sf_scale or -1.0,
         o_sf_vec_size or -1,
+        o_sf_start_index,
         batch_size,
         window_left,
         cum_seq_lens_q,
         cum_seq_lens_kv,
         sm_count,
     )
     return (
-        out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, query.shape)
+        out
+        if out_dtype != "nvfp4"
+        else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
     )
```

flashinfer/utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -524,6 +524,7 @@ def __init__(
         self,
         data: torch.Tensor,
         scale: torch.Tensor,
+        scale_start_index: int = 0,
         original_shape: Optional[Tuple[int, ...]] = None,
     ):
         """Initialize FP4Tensor.
@@ -534,6 +535,8 @@ def __init__(
             uint8 tensor storing the compressed FP4 data
         scale : torch.Tensor
             float8_e4m3fn tensor storing the scale factors
+        scale_start_index : int
+            The start token index of the scale factors. This is needed when two kernels (like prefill and decode kernels) are reusing the same scale factor tensor with different offsets.
         original_shape : Optional[Tuple[int, ...]]
             The original shape before compression.
         """
@@ -561,6 +564,7 @@ def __init__(
 
         self.data = data
         self.scale = scale
+        self.scale_start_index = scale_start_index
         self.original_shape = original_shape
         self.dtype = "nvfp4"
 
```
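A minimal, hedged example of the extended constructor; the shapes below are illustrative and assume the constructor's existing checks accept a scale buffer padded to the swizzled layout:

```python
import torch

from flashinfer.utils import FP4Tensor

num_tokens, num_heads, head_dim = 57, 8, 128
data = torch.empty(num_tokens, num_heads, head_dim // 2, dtype=torch.uint8)  # packed FP4 payload
scale = torch.empty(128, 64, dtype=torch.float8_e4m3fn)  # rows padded to 128, columns to a multiple of 4

# scale_start_index is a token-row offset into `scale`, not a byte offset:
# rows [13, 13 + 57) of the swizzled layout belong to this tensor.
t = FP4Tensor(data, scale, scale_start_index=13, original_shape=(num_tokens, num_heads, head_dim))
```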

tests/test_trtllm_gen_context.py

Lines changed: 41 additions & 3 deletions
```diff
@@ -5,6 +5,15 @@
 from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
 
 import flashinfer
+from flashinfer.utils import FP4Tensor
+
+
+def flip_coin(*args, **kwargs):
+    # Use any test parameters to deterministically decide branch
+    # This makes test configurations go through different paths
+    param_tuple = args + tuple(sorted(kwargs.items()))
+    hash_value = hash(param_tuple)
+    return (hash_value % 2) == 0
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -327,7 +336,7 @@ def test_trtllm_batch_prefill(
     o_sf_scale = (
         300 if o_dtype == "nvfp4" else None
     )  # choose a value to make error smaller by testing.
-
+    o_sf_vec_size = 16 if o_dtype == "nvfp4" else None
     sm_scale = float(1.0 / (head_dim**0.5))
 
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
@@ -344,6 +353,29 @@
         ]
     )
 
+    if flip_coin(batch_size, page_size, num_kv_heads, head_grp_size, o_dtype):
+        if o_dtype == "nvfp4":
+            fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),)
+
+            fp4_out_scale_shape = (
+                math.ceil(q.shape[0] / 128) * 128,
+                math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4,
+            )
+
+            out_scale_factor = torch.empty(
+                fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device
+            )
+            extra_size = fp4_out_scale_shape[0] - q.shape[0]
+            o_sf_start_index = (
+                torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0
+            )
+            out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device)
+            out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index)
+        else:
+            out = torch.empty_like(q, dtype=dtype_map[o_dtype])
+    else:
+        out = None
+
     output = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
         q.contiguous(),
         kv_cache,
@@ -358,14 +390,16 @@
         q_indptr,
         kv_indptr,
         window_left,  # window_left
+        out=out,
         out_dtype=dtype_map[o_dtype],
         o_sf_scale=o_sf_scale,
-        o_sf_vec_size=16 if o_dtype == "nvfp4" else None,
+        o_sf_vec_size=o_sf_vec_size,
     )
 
     # Handle different return types based on out_dtype
     if o_dtype == "nvfp4":
         out_scale_factor = output.scale  # FP4Tensor.scale
+        o_sf_start_index = output.scale_start_index
         output = output.data  # FP4Tensor.data
     else:
         out_scale_factor = None
@@ -407,7 +441,11 @@
         output = cast_from_fp4(output)
         output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16)
         out_scale_factor = recover_swizzled_scales(
-            out_scale_factor, output.shape[0], output.shape[1] * output.shape[2], 16
+            out_scale_factor,
+            output.shape[0],
+            output.shape[1] * output.shape[2],
+            16,
+            o_sf_start_index,
         )
 
     torch.testing.assert_close(
```
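Worked numbers for the buffer this test allocates (a hedged illustration with made-up sizes, not test code): for a query of shape `(57, 8, 128)` and `o_sf_vec_size == 16`, the padded scale-factor shape and the range of valid start rows come out as follows.

```python
import math

num_tokens, num_heads, head_dim, o_sf_vec_size = 57, 8, 128, 16

rows = math.ceil(num_tokens / 128) * 128                        # 128: rows padded to a multiple of 128
cols = math.ceil(num_heads * head_dim / o_sf_vec_size / 4) * 4  # 64: ceil(1024 / 16 / 4) * 4
extra = rows - num_tokens                                       # 71 rows of slack in the padded buffer

# The test draws o_sf_start_index uniformly from [0, extra), so the written rows
# always stay inside the padded, swizzled scale tensor.
assert rows == 128 and cols == 64 and extra == 71
```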

tests/test_trtllm_gen_decode.py

Lines changed: 41 additions & 2 deletions
```diff
@@ -7,6 +7,15 @@
 from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
 
 import flashinfer
+from flashinfer.utils import FP4Tensor
+
+
+def flip_coin(*args, **kwargs):
+    # Use any test parameters to deterministically decide branch
+    # This makes test configurations go through different paths
+    param_tuple = args + tuple(sorted(kwargs.items()))
+    hash_value = hash(param_tuple)
+    return (hash_value % 2) == 0
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -224,6 +233,7 @@ def test_trtllm_batch_decode_fmha(
     o_sf_scale = (
         300 if o_dtype == "nvfp4" else None
     )  # choose a value to make error smaller by testing.
+    o_sf_vec_size = 16 if o_dtype == "nvfp4" else None
 
     sm_scale = float(1.0 / (head_dim**0.5))
 
@@ -237,6 +247,29 @@
         ]
     )
 
+    if flip_coin(batch_size, page_size, num_kv_heads, head_grp_size, o_dtype):
+        if o_dtype == "nvfp4":
+            fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),)
+
+            fp4_out_scale_shape = (
+                math.ceil(q.shape[0] / 128) * 128,
+                math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4,
+            )
+
+            out_scale_factor = torch.empty(
+                fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device
+            )
+            extra_size = fp4_out_scale_shape[0] - q.shape[0]
+            o_sf_start_index = (
+                torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0
+            )
+            out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device)
+            out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index)
+        else:
+            out = torch.empty_like(q, dtype=dtype_map[o_dtype])
+    else:
+        out = None
+
     output = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         q.contiguous(),
         kv_cache,
@@ -247,14 +280,16 @@
         q_scale * k_scale * sm_scale,  # bmm1_scale
         v_scale / o_scale,  # bmm2_scale
         window_left,  # window_left
+        out=out,
         out_dtype=dtype_map[o_dtype],
         o_sf_scale=o_sf_scale,
-        o_sf_vec_size=16 if o_dtype == "nvfp4" else None,
+        o_sf_vec_size=o_sf_vec_size,
     )
 
     # Handle different return types based on out_dtype
     if o_dtype == "nvfp4":
         out_scale_factor = output.scale  # FP4Tensor.scale
+        o_sf_start_index = output.scale_start_index
         output = output.data  # FP4Tensor.data
     else:
         out_scale_factor = None
@@ -297,7 +332,11 @@
         output = cast_from_fp4(output)
         output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16)
         out_scale_factor = recover_swizzled_scales(
-            out_scale_factor, output.shape[0], output.shape[1] * output.shape[2], 16
+            out_scale_factor,
+            output.shape[0],
+            output.shape[1] * output.shape[2],
+            16,
+            o_sf_start_index,
        )
 
     torch.testing.assert_close(
```

tests/utils_fp4.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -84,12 +84,13 @@ def ref_nvfp4_quant(x, global_scale, block_size):
     return cast_to_fp4(clipped_x), scale.squeeze(-1)
 
 
-def recover_swizzled_scales(scale, m, n, block_size):
+def recover_swizzled_scales(scale, m, n, block_size, sf_start_index=0):
+    assert sf_start_index + m <= scale.shape[0]
     rounded_m = utils.round_up(m, 128)
     scale_n = n // block_size
     rounded_n = utils.round_up(scale_n, 4)
     # Recover the swizzled scaling factor to linear layout
     tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
     tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
     result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
-    return result[:m, :scale_n]
+    return result[sf_start_index : sf_start_index + m, :scale_n]
```
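A hedged usage sketch of the updated helper (run from the `tests/` directory so `utils_fp4` is importable; the buffer contents are random and only the shapes matter here):

```python
import torch

from utils_fp4 import recover_swizzled_scales  # test helper in tests/utils_fp4.py

m, n, vec, start = 57, 8 * 128, 16, 13
# Padded swizzled buffer: 128 rows (multiple of 128), 64 columns (n // vec, multiple of 4).
swizzled = torch.randn(128, 64).to(torch.float8_e4m3fn)

linear = recover_swizzled_scales(swizzled, m, n, vec, start)
# `linear` has shape (m, n // vec) == (57, 64): one float32 scale per 16-element block,
# taken from rows [start, start + m) of the unswizzled layout.
assert linear.shape == (57, 64)
```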
