Commit 74b8785
bugfix: Fix TRTLLM NVFP4-out attention kernel scale factor dim issue (#1460)
## 📌 Description

Fixes the shape check for the FP4 scale-factor tensor. Since #1363, callers can pass `o_sf_start_index` to write into a scale-factor tensor shared by the prefill and decode kernels. The current implementation still assumes the batch dimension of the scale-factor tensor matches the query's, but the tensor should hold the combined prefill and decode scales. This PR fixes the check and performs the correct swizzle recovery.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 7146ebc commit 74b8785
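
To make the intent concrete, here is a minimal sketch (sizes illustrative, not from the PR) of the shared scale-factor buffer this fix accommodates: one tensor whose rows hold the prefill scales first and the decode scales after them, addressed via the `scale_start_index` constructor argument the tests already use:

```python
import torch
from flashinfer.utils import FP4Tensor, round_up

# Illustrative sizes: 300 prefill tokens + 57 decode tokens,
# 16 heads, head_dim 128, scale vector size 16.
num_prefill, num_decode, h, d, vec = 300, 57, 16, 128, 16

# One scale buffer sized for the combined batch, padded to the
# 128-row swizzle tile.
scale = torch.empty(
    (round_up(num_prefill + num_decode, 128), round_up(h * d // vec, 4)),
    dtype=torch.float8_e4m3fn,
)

# Prefill scales occupy rows [0, 300); decode scales start at row 300.
prefill_out = FP4Tensor(
    torch.empty((num_prefill, h, d // 2), dtype=torch.uint8), scale, 0
)
decode_out = FP4Tensor(
    torch.empty((num_decode, h, d // 2), dtype=torch.uint8), scale, num_prefill
)
```

Before this fix, the decode path would reject `decode_out`, since `scale.shape[0]` (384 here) is not `round_up(57, 128)`.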

File tree: 5 files changed, +81 −23 lines

- flashinfer/decode.py
- flashinfer/prefill.py
- flashinfer/utils.py
- tests/test_trtllm_gen_attention.py
- tests/utils_fp4.py

flashinfer/decode.py (26 additions, 7 deletions)

```diff
@@ -60,6 +60,8 @@
     is_float8,
     register_custom_op,
     register_fake_op,
+    ceil_div,
+    round_up,
 )


@@ -2085,18 +2087,21 @@ def trtllm_batch_decode_with_kv_cache(
         assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported"
         o_sf_vec_size = o_sf_vec_size or 16

-        fp4_out_shape = query.shape[:-1] + (math.ceil(query.shape[-1] / 2),)
-
-        fp4_out_scale_shape = (
-            math.ceil(query.shape[0] / 128) * 128,
-            math.ceil(query.shape[1] * query.shape[2] / o_sf_vec_size / 4) * 4,
-        )
+        fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),)

         if isinstance(out, FP4Tensor):
+            fp4_out_scale_shape = (
+                out.scale.shape[0],
+                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
+            )
             out_scale_factor = out.scale
             o_sf_start_index = out.scale_start_index
             out = out.data
         elif out is None:
+            fp4_out_scale_shape = (
+                round_up(query.shape[0], 128),
+                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
+            )
             out_scale_factor = torch.empty(
                 fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
             )
@@ -2105,16 +2110,30 @@ def trtllm_batch_decode_with_kv_cache(
         else:
             raise ValueError(f"Invalid out: {out}")

-        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        assert isinstance(out, torch.Tensor)

         # Use uint8 as the container dtype to compliant with next fp4 gemm.
+        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+
         _check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
             query.device,
             "out_scale_factor",
         )
+
+        # Check o_sf_start_index is valid
+        if (
+            o_sf_start_index < 0
+            or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0]
+        ):
+            raise ValueError(
+                f"o_sf_start_index is out of the valid range of out_scale_factor. "
+                f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, "
+                f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}"
+            )
+
     elif isinstance(out_dtype, torch.dtype) or out_dtype is None:
         assert o_sf_scale is None
         assert o_sf_vec_size is None
```
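
As a sanity check on the new arithmetic, here is a small standalone sketch; `ceil_div` and `round_up` are re-implemented locally to mirror the `flashinfer.utils` helpers, and the query shape is made up for illustration:

```python
def ceil_div(a, b):  # mirrors flashinfer.utils.ceil_div
    return -(a // -b)

def round_up(a, b):  # mirrors flashinfer.utils.round_up
    return ceil_div(a, b) * b

q_shape = (57, 16, 128)  # (num_tokens, num_heads, head_dim), illustrative
o_sf_vec_size = 16

# Packed FP4 output: two 4-bit values per uint8 byte.
fp4_out_shape = q_shape[:-1] + (ceil_div(q_shape[-1], 2),)
assert fp4_out_shape == (57, 16, 64)

# Scale-factor columns are computed the same way in both branches.
sf_cols = round_up(q_shape[1] * q_shape[2] // o_sf_vec_size, 4)
assert sf_cols == 128  # 16 * 128 / 16 = 128, already a multiple of 4

# out is None: allocate rows padded to the 128-row swizzle tile.
assert round_up(q_shape[0], 128) == 128
# out is an FP4Tensor: accept whatever row count the caller's buffer has
# (e.g. 384 rows covering prefill + decode) instead of requiring 128.
```

Taking the row count from `out.scale.shape[0]` in the `FP4Tensor` branch is what lets a caller-provided buffer be larger than the query's padded batch.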

flashinfer/prefill.py (26 additions, 7 deletions)

```diff
@@ -56,6 +56,8 @@
     is_sm100a_supported,
     register_custom_op,
     register_fake_op,
+    ceil_div,
+    round_up,
 )


@@ -3216,18 +3218,21 @@ def trtllm_batch_context_with_kv_cache(
         assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported"
         o_sf_vec_size = o_sf_vec_size or 16

-        fp4_out_shape = query.shape[:-1] + (math.ceil(query.shape[-1] / 2),)
-
-        fp4_out_scale_shape = (
-            math.ceil(query.shape[0] / 128) * 128,
-            math.ceil(query.shape[1] * query.shape[2] / o_sf_vec_size / 4) * 4,
-        )
+        fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),)

         if isinstance(out, FP4Tensor):
+            fp4_out_scale_shape = (
+                out.scale.shape[0],
+                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
+            )
             out_scale_factor = out.scale
             o_sf_start_index = out.scale_start_index
             out = out.data
         elif out is None:
+            fp4_out_scale_shape = (
+                round_up(query.shape[0], 128),
+                round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
+            )
             out_scale_factor = torch.empty(
                 fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
             )
@@ -3236,16 +3241,30 @@ def trtllm_batch_context_with_kv_cache(
         else:
             raise ValueError(f"Invalid out: {out}")

-        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        assert isinstance(out, torch.Tensor)

         # Use uint8 as the container dtype to compliant with next fp4 gemm.
+        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+
         _check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
             query.device,
             "out_scale_factor",
         )
+
+        # Check o_sf_start_index is valid
+        if (
+            o_sf_start_index < 0
+            or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0]
+        ):
+            raise ValueError(
+                f"o_sf_start_index is out of the valid range of out_scale_factor. "
+                f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, "
+                f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}"
+            )
+
     elif isinstance(out_dtype, torch.dtype) or out_dtype is None:
         assert o_sf_scale is None
         assert o_sf_vec_size is None
```

flashinfer/utils.py (16 additions, 0 deletions)

```diff
@@ -546,8 +546,24 @@ def __init__(
         """
         if data.dtype != torch.uint8:
             raise ValueError(f"data must be uint8 tensor, got {data.dtype}")
+
+        # Validate scale factor tensor and scale start index
         if scale.dtype != torch.float8_e4m3fn:
             raise ValueError(f"scale must be float8_e4m3fn tensor, got {scale.dtype}")
+        if scale.shape[0] % 128 != 0:
+            raise ValueError(
+                f"scale.shape[0] must be a multiple of 128, got {scale.shape[0]}"
+            )
+        if scale_start_index < 0 or scale_start_index >= scale.shape[0]:
+            raise ValueError(
+                f"scale start index must be in the range [0, scale.shape[0]). "
+                f"scale_start_index={scale_start_index}, scale.shape[0]={scale.shape[0]}"
+            )
+        if scale_start_index + data.shape[0] > scale.shape[0]:
+            raise ValueError(
+                f"scale start index + data.shape[0] must not exceed scale.shape[0]. "
+                f"scale_start_index={scale_start_index}, data.shape[0]={data.shape[0]}, scale.shape[0]={scale.shape[0]}"
+            )

         # Validate shape relationship if original_shape is provided
         if original_shape is not None:
```
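
A quick sketch of how the new constructor checks behave, assuming the positional signature used in the tests (`FP4Tensor(data, scale, scale_start_index)`); the shapes are illustrative, and CPU tensors suffice to exercise the validation:

```python
import torch
from flashinfer.utils import FP4Tensor

data = torch.empty((96, 16, 64), dtype=torch.uint8)        # 96 token rows
scale = torch.empty((256, 64), dtype=torch.float8_e4m3fn)  # 256 % 128 == 0

FP4Tensor(data, scale, 128)  # ok: 128 + 96 <= 256

try:
    FP4Tensor(data, scale, 200)  # 200 + 96 > 256
except ValueError as e:
    print(e)  # "scale start index + data.shape[0] must not exceed ..."
```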

tests/test_trtllm_gen_attention.py (10 additions, 6 deletions)

```diff
@@ -5,7 +5,7 @@
 from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant

 import flashinfer
-from flashinfer.utils import FP4Tensor
+from flashinfer.utils import FP4Tensor, ceil_div, round_up

 DTYPE_MAP = {
     "half": torch.float16,
@@ -162,19 +162,23 @@ def create_output(q, o_dtype, create_out_tensor):

     if create_out_tensor:
         if o_dtype == "nvfp4":
-            fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),)
+            fp4_out_shape = q.shape[:-1] + (ceil_div(q.shape[-1], 2),)
+
+            extra_size = torch.randint(0, 256, (1,)).item()

             fp4_out_scale_shape = (
-                math.ceil(q.shape[0] / 128) * 128,
-                math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4,
+                round_up(q.shape[0] + extra_size, 128),
+                round_up(q.shape[1] * q.shape[2] // o_sf_vec_size, 4),
             )

             out_scale_factor = torch.empty(
                 fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device
             )
-            extra_size = fp4_out_scale_shape[0] - q.shape[0]
+            rounded_extra_size = fp4_out_scale_shape[0] - q.shape[0]
             o_sf_start_index = (
-                torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0
+                torch.randint(0, rounded_extra_size, (1,)).item()
+                if rounded_extra_size > 0
+                else 0
             )
             out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device)
             out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index)
```
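
Tracing the test's randomized padding by hand (a standalone sketch, with `round_up` redefined locally and an illustrative token count) shows why every drawn start index is valid:

```python
import torch

def round_up(a, b):  # mirrors flashinfer.utils.round_up
    return (a + b - 1) // b * b

q_rows = 57  # illustrative token count
extra_size = torch.randint(0, 256, (1,)).item()
rows = round_up(q_rows + extra_size, 128)  # scale buffer rows
rounded_extra_size = rows - q_rows         # always >= extra_size
o_sf_start_index = (
    torch.randint(0, rounded_extra_size, (1,)).item()
    if rounded_extra_size > 0
    else 0
)
# Any drawn start index leaves room for all q_rows scale rows.
assert o_sf_start_index + q_rows <= rows
```

Adding the random `extra_size` before rounding makes the buffer strictly taller than `round_up(q_rows, 128)` in most draws, so nonzero start indices actually get exercised.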

tests/utils_fp4.py (3 additions, 3 deletions)

```diff
@@ -90,11 +90,11 @@ def ref_fp4_quant(x, global_scale, block_size, sf_use_ue8m0=False):

 def recover_swizzled_scales(scale, m, n, block_size, sf_start_index=0):
     assert sf_start_index + m <= scale.shape[0]
-    rounded_m = utils.round_up(m, 128)
+    full_m = scale.shape[0]
     scale_n = n // block_size
     rounded_n = utils.round_up(scale_n, 4)
     # Recover the swizzled scaling factor to linear layout
-    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
+    tmp = torch.reshape(scale, (1, full_m // 128, rounded_n // 4, 32, 4, 4))
     tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
-    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
+    result = torch.reshape(tmp, (full_m, rounded_n)).to(torch.float32)
     return result[sf_start_index : sf_start_index + m, :scale_n]
```
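
The reshape/permute in `recover_swizzled_scales` undoes a 128×4 tile swizzle, and the permutation `(0, 1, 4, 3, 2, 5)` is its own inverse, so the round trip can be checked with a standalone sketch (`swizzle` below is a hypothetical forward transform written only for this check, not a FlashInfer API):

```python
import torch

def swizzle(linear):
    # Hypothetical forward of the layout that recover_swizzled_scales undoes.
    m, n = linear.shape  # requires m % 128 == 0 and n % 4 == 0
    t = linear.reshape(1, m // 128, 4, 32, n // 4, 4)
    return t.permute(0, 1, 4, 3, 2, 5).reshape(m, n)

def recover(scale):
    # Mirrors tests/utils_fp4.py after this commit (full_m = scale.shape[0]).
    m, n = scale.shape
    t = scale.reshape(1, m // 128, n // 4, 32, 4, 4)
    return t.permute(0, 1, 4, 3, 2, 5).reshape(m, n)

linear = torch.arange(256 * 8, dtype=torch.float32).reshape(256, 8)
assert torch.equal(recover(swizzle(linear)), linear)
```

With `full_m = scale.shape[0]`, the recovery works for buffers taller than `round_up(m, 128)`, which the old `rounded_m` variant would have reshaped incorrectly.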
