Commit cd928a7
support trtllm-gen prefill fp4 output (#1360)
## 📌 Description

Support nvfp4 output for the prefill function call (not the wrapper API yet). The nvfp4 test won't pass until the trtllm-gen kernels are updated, since the current kernels have a bug that ignores v_scale. I tested locally and will update the kernels later.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent d8e7d6a commit cd928a7
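
As an illustration of the nvfp4 prefill output path this commit adds, here is a minimal sketch of calling `trtllm_batch_context_with_kv_cache` with `out_dtype="nvfp4"`. This is not code from the PR: the shapes, dtypes, scale values, and workspace size are illustrative placeholders, parameter names follow the docstring in this diff, and an fp8 query plus a GPU supported by the trtllm-gen kernels are assumed.

```python
import torch
from flashinfer.prefill import trtllm_batch_context_with_kv_cache

# Illustrative sizes only.
batch_size, page_size, pages_per_seq = 2, 16, 4
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
q_len_per_req = 4
num_tokens = batch_size * q_len_per_req
num_pages = batch_size * pages_per_seq

query = torch.randn(num_tokens, num_qo_heads, head_dim,
                    device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                       device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
block_tables = torch.arange(num_pages, dtype=torch.int32,
                            device="cuda").view(batch_size, pages_per_seq)
seq_lens = torch.full((batch_size,), page_size * pages_per_seq,
                      dtype=torch.int32, device="cuda")
cum_seq_lens_q = torch.arange(0, num_tokens + 1, q_len_per_req,
                              dtype=torch.int32, device="cuda")
cum_seq_lens_kv = torch.cat(
    [torch.zeros(1, dtype=torch.int32, device="cuda"), seq_lens]
).cumsum(0).to(torch.int32)

out = trtllm_batch_context_with_kv_cache(
    query=query, kv_cache=kv_cache, workspace_buffer=workspace,
    block_tables=block_tables, seq_lens=seq_lens,
    max_q_len=q_len_per_req, max_kv_len=int(seq_lens.max()),
    bmm1_scale=1.0 / head_dim**0.5, bmm2_scale=1.0,
    batch_size=batch_size, cum_seq_lens_q=cum_seq_lens_q,
    cum_seq_lens_kv=cum_seq_lens_kv,
    out_dtype="nvfp4", o_sf_scale=300.0, o_sf_vec_size=16,
)
# `out` is an FP4Tensor: packed fp4 payload in `out.data` (uint8) plus
# per-block scale factors in `out.scale` (float8_e4m3fn).
```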

File tree

6 files changed: 120 additions & 34 deletions

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 10 additions & 8 deletions
@@ -184,7 +184,7 @@ inline Data_type torch_dtype_to_tllm_data_type(at::ScalarType dtype) {

 inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; }

-void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> const out_scale_factor,
+void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out_scale_factor,
                                    at::Tensor query, at::Tensor key_value_cache,
                                    at::Tensor workspace_buffer, at::Tensor block_tables,
                                    at::Tensor seq_lens, int64_t max_kv_len, double bmm1_scale,
@@ -245,12 +245,14 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> con
       bmm2_scale, o_sf_scale, o_sf_vec_size, window_left, sum_seq_q, sm_count, stream);
 }

-void trtllm_paged_attention_context(at::Tensor out, at::Tensor query, at::Tensor key_value_cache,
+void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> out_scale_factor,
+                                    at::Tensor query, at::Tensor key_value_cache,
                                     at::Tensor workspace_buffer, at::Tensor block_tables,
                                     at::Tensor seq_lens, int64_t max_q_len, int64_t max_kv_len,
-                                    double bmm1_scale, double bmm2_scale, int64_t batch_size,
-                                    int64_t window_left, at::Tensor cum_seq_lens_q,
-                                    at::Tensor cum_seq_lens_kv, int64_t sm_count) {
+                                    double bmm1_scale, double bmm2_scale, double o_sf_scale,
+                                    int64_t o_sf_vec_size, int64_t batch_size, int64_t window_left,
+                                    at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv,
+                                    int64_t sm_count) {
   auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
   auto kv_data_type = torch_dtype_to_tllm_data_type(key_value_cache.scalar_type());
   auto o_data_type = torch_dtype_to_tllm_data_type(out.scalar_type());
@@ -284,9 +286,10 @@ void trtllm_paged_attention_context(at::Tensor out, at::Tensor query, at::Tensor

   auto device = query.device();
   const auto stream = at::cuda::getCurrentCUDAStream(device.index());
+  void* output_sf_ptr = out_scale_factor ? out_scale_factor.value().data_ptr() : nullptr;

   trtllm_paged_attention_launcher(
-      out.data_ptr(), /*out_scale_factor=*/nullptr, query.data_ptr(), key_value_cache.data_ptr(),
+      out.data_ptr(), output_sf_ptr, query.data_ptr(), key_value_cache.data_ptr(),
       (char*)key_value_cache.data_ptr() +
           (share_kv_cache ? 0 : key_value_cache.stride(1) * key_value_cache.element_size()),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -296,8 +299,7 @@ void trtllm_paged_attention_context(at::Tensor out, at::Tensor query, at::Tensor
       o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
-      bmm2_scale, /* o_sf_scale =*/-1, /* o_sf_vec_size =*/-1, window_left, sum_seq_q, sm_count,
-      stream);
+      bmm2_scale, o_sf_scale, o_sf_vec_size, window_left, sum_seq_q, sm_count, stream);
 }

 namespace trtllm_cubin_loader {

flashinfer/decode.py

Lines changed: 4 additions & 4 deletions
@@ -1988,12 +1988,12 @@ def trtllm_batch_decode_with_kv_cache(
         block_tables: page_table of kv cache, [batch_size, num_pages]
         seq_lens: A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
         max_seq_len: max sequence length for kv_cache
-        out: output tensor, if not provided, will be allocated with ``out_dtype``, if ``out_dtype`` is not provided, will use the type of ``query``.
-        out_dtype: output dtype, if not provided, will use the type of ``out``.
         bmm1_scale: fused scale for bmm1 input.
         bmm2_scale: fused scale for bmm2 input.
         window_left: The left (inclusive) window size for the attention window, when set to ``-1``, the window
             size will be set to the full length of the sequence. Defaults to ``-1``.
+        out: output tensor, if not provided, will be allocated with ``out_dtype``, if ``out_dtype`` is not provided, will use the type of ``query``.
+        out_dtype: output dtype, if not provided, will use the type of ``out``. For nvfp4, use string ``nvfp4``.
         o_sf_scale: scale for nvfp4 output tensor scale factor.
         o_sf_vec_size: vector size for nvfp4 output tensor scale factor.

@@ -2020,8 +2020,8 @@ def trtllm_batch_decode_with_kv_cache(
     )

     if isinstance(out, FP4Tensor):
-        out_scale_factor = out.scale_factor
-        out = out.tensor
+        out_scale_factor = out.scale
+        out = out.data
     elif out is None:
         out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
         out_scale_factor = torch.empty(
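
For reference, the decode path above now reads the packed payload and block scale factors through the renamed `FP4Tensor` fields (`.data` and `.scale` instead of `.tensor` and `.scale_factor`). A minimal sketch of consuming such a result; the helper name is hypothetical:

```python
from flashinfer.utils import FP4Tensor

def unpack_fp4(out: FP4Tensor):
    """Split an nvfp4 attention output into its two buffers."""
    packed = out.data    # torch.uint8, two fp4 values packed per byte
    scales = out.scale   # torch.float8_e4m3fn block scale factors
    return packed, scales
```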

flashinfer/prefill.py

Lines changed: 83 additions & 6 deletions
@@ -37,6 +37,7 @@
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
 from .quantization import packbits, segment_packbits
 from .utils import (
+    FP4Tensor,
     MaskMode,
     PosEncodingMode,
     TensorLayout,
@@ -115,6 +116,7 @@ def _paged_run(
         out = torch.empty_like(query)
     op.trtllm_paged_attention_context(
         out,
+        None,  # fp4 output not supported in wrapper api yet.
         query,
         kv_cache,
         workspace_buffer,
@@ -124,6 +126,8 @@ def _paged_run(
         max_kv_len,
         bmm1_scale,
         bmm2_scale,
+        -1,  # o_sf_scale
+        -1,  # o_sf_vec_size
         batch_size,
         window_left,
         cum_seq_lens_q,
@@ -2964,18 +2968,87 @@ def trtllm_batch_context_with_kv_cache(
     cum_seq_lens_q: torch.Tensor,
     cum_seq_lens_kv: torch.Tensor,
     window_left: int = -1,
-    out: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+    out: Optional[Union[torch.Tensor, FP4Tensor]] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    o_sf_scale: Optional[float] = None,
+    o_sf_vec_size: Optional[int] = None,
+) -> Union[torch.Tensor, FP4Tensor]:
+    """
+    Parameters:
+        query: query tensor with shape [num_tokens, num_heads, head_dim]
+        kv_cache: kv_cache tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
+        workspace_buffer: workspace
+        block_tables: page_table of kv cache, [batch_size, num_pages]
+        seq_lens: A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
+        max_q_len: max sequence length for query
+        max_kv_len: max sequence length for kv_cache
+        bmm1_scale: fused scale for bmm1 input.
+        bmm2_scale: fused scale for bmm2 input.
+        batch_size: batch size
+        cum_seq_lens_q: cumulative sequence length for query. shape: ``[batch_size + 1]``
+        cum_seq_lens_kv: cumulative sequence length for kv_cache. shape: ``[batch_size + 1]``
+        window_left: The left (inclusive) window size for the attention window, when set to ``-1``, the window
+            size will be set to the full length of the sequence. Defaults to ``-1``.
+        out: output tensor, if not provided, will be allocated with ``out_dtype``, if ``out_dtype`` is not provided, will use the type of ``query``.
+        out_dtype: output dtype, if not provided, will use the type of ``out``. For nvfp4, use string ``nvfp4``.
+        o_sf_scale: scale for nvfp4 output tensor scale factor.
+        o_sf_vec_size: vector size for nvfp4 output tensor scale factor.
+
+    Returns:
+        out: output torch.Tensor or FP4Tensor.
+    """
     run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context
     sm_count = get_device_sm_count(query.device)

-    if out is None:
-        out = torch.empty_like(query)
-    else:
+    if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)):
+        assert (
+            query.dtype == torch.float8_e4m3fn
+        ), "query must be fp8 when out_dtype is nvfp4."
+        assert o_sf_scale is not None
+        assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported"
+        o_sf_vec_size = o_sf_vec_size or 16
+
+        fp4_out_shape = query.shape[:-1] + (math.ceil(query.shape[-1] / 2),)
+
+        fp4_out_scale_shape = (
+            math.ceil(query.shape[0] / 128) * 128,
+            math.ceil(query.shape[1] * query.shape[2] / o_sf_vec_size / 4) * 4,
+        )
+
+        if isinstance(out, FP4Tensor):
+            out_scale_factor = out.scale
+            out = out.data
+        elif out is None:
+            out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
+            out_scale_factor = torch.empty(
+                fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
+            )
+        else:
+            raise ValueError(f"Invalid out: {out}")
+
+        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+
+        # Use uint8 as the container dtype to compliant with next fp4 gemm.
+        _check_shape_dtype_device(
+            out_scale_factor,
+            fp4_out_scale_shape,
+            torch.float8_e4m3fn,
+            query.device,
+            "out_scale_factor",
+        )
+    elif isinstance(out_dtype, torch.dtype) or out_dtype is None:
+        assert o_sf_scale is None
+        assert o_sf_vec_size is None
+        out_scale_factor = None
+        out_dtype = out_dtype or query.dtype
+        out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
         _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
+    else:
+        raise ValueError(f"Invalid out_dtype: {out_dtype}")

     run_func(
         out,
+        out_scale_factor,
         query,
         kv_cache,
         workspace_buffer,
@@ -2985,10 +3058,14 @@ def trtllm_batch_context_with_kv_cache(
         max_kv_len,
         bmm1_scale,
         bmm2_scale,
+        o_sf_scale or -1.0,
+        o_sf_vec_size or -1,
         batch_size,
         window_left,
         cum_seq_lens_q,
         cum_seq_lens_kv,
         sm_count,
     )
-    return out
+    return (
+        out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, query.shape)
+    )
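
The nvfp4 branch added above sizes its output buffers from the query shape: the packed payload halves the head dimension (two fp4 values per uint8 byte), and the scale-factor buffer is padded to 128-token rows and groups of 4 scale columns. A small worked example of that arithmetic, with illustrative dimensions:

```python
import math

# Illustrative dimensions; the formulas mirror the allocation code above.
num_tokens, num_heads, head_dim, o_sf_vec_size = 300, 8, 128, 16

# Packed fp4 payload: last dim halved because two values share one byte.
fp4_out_shape = (num_tokens, num_heads, math.ceil(head_dim / 2))            # (300, 8, 64)

# Scale factors: rows padded to a multiple of 128 tokens, columns to a
# multiple of 4 groups of o_sf_vec_size elements.
fp4_out_scale_shape = (
    math.ceil(num_tokens / 128) * 128,                                       # 384
    math.ceil(num_heads * head_dim / o_sf_vec_size / 4) * 4,                 # 64
)
print(fp4_out_shape, fp4_out_scale_shape)
```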

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 2 additions & 2 deletions
@@ -237,7 +237,7 @@ class TllmGenFmhaKernel {
   static std::string getCubinPath() {
     const char* env_hash = std::getenv("FLASHINFER_CUBIN_ARTIFACTORY_HASH");
     std::string hash =
-        env_hash ? std::string(env_hash) : "4c7bdebb4eba13311fc652a069e64782d5c0723d";
+        env_hash ? std::string(env_hash) : "52e676342c67a3772e06f10b84600044c0c22b76";
     std::string cubin_path = hash + "/fmha/trtllm-gen/";
     return cubin_path;
   }
@@ -595,7 +595,7 @@ class TllmFmhaKernelFactory {
     if (!metainfo_loaded) {
       std::string metainfo_raw =
           getMetaInfo(TllmGenFmhaKernel::getCubinPath() + "flashInferMetaInfo",
-                      "b3907fa4e30a75a0f72cfded44e6cf0f04fe5868166659732487726cbc23c0b9", ".h");
+                      "8c5630020c0452fb1cd1ea7e3b8fdbb7bf94f71bd899ed5b704a490bdb4f7368", ".h");
       metainfo = KernelType::KernelMeta::loadFromMetaInfoRaw(metainfo_raw);
       metainfo_loaded = true;
     }

tests/test_trtllm_gen_context.py

Lines changed: 13 additions & 9 deletions
@@ -205,13 +205,9 @@ def test_trtllm_batch_context_wrapper(
     "q_dtype,kv_cache_dtype,o_dtype",
     [
         ("half", "half", "half"),
-        # ("half", "fp8", "half"),
         ("bf16", "bf16", "bf16"),
-        # ("bf16", "fp8", "bf16"),
         ("fp8", "fp8", "fp8"),
-        # ("fp8", "fp8", "half"),
-        # ("fp8", "fp8", "bf16"),
-        # ("fp8", "fp8", "nvfp4"),
+        ("fp8", "fp8", "nvfp4"),
     ],
 )
 def test_trtllm_batch_prefill(
@@ -329,7 +325,7 @@ def test_trtllm_batch_prefill(
     else:
         o_scale = 1.0
     o_sf_scale = (
-        0.2 if o_dtype == "nvfp4" else None
+        300 if o_dtype == "nvfp4" else None
    )  # choose a value to make error smaller by testing.

     sm_scale = float(1.0 / (head_dim**0.5))
@@ -362,6 +358,9 @@ def test_trtllm_batch_prefill(
         q_indptr,
         kv_indptr,
         window_left,  # window_left
+        out_dtype=dtype_map[o_dtype],
+        o_sf_scale=o_sf_scale,
+        o_sf_vec_size=16 if o_dtype == "nvfp4" else None,
     )

     # Handle different return types based on out_dtype
@@ -398,7 +397,7 @@ def test_trtllm_batch_prefill(
     output_ref = wrapper.run(ref_q, ref_kv_cache)

     if q_dtype == "fp8" and o_dtype == "nvfp4":
-        rtol, atol = 5e-1, 1.1e0
+        rtol, atol = 4e-1, 1e0
     elif q_dtype == "fp8" and o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
     else:
@@ -414,9 +413,14 @@ def test_trtllm_batch_prefill(
         torch.testing.assert_close(
             out_scale_factor.float().reshape(out_scale_factor_ref.shape),
             out_scale_factor_ref.float(),
-            rtol=rtol,
-            atol=atol,
+            rtol=2e-1,
+            atol=2e-1,
         )
+        rmse = torch.sqrt(
+            torch.mean((output.float() * o_scale - output_ref.float()) ** 2)
+        )
+        assert rmse.item() < 0.3

     # convert to float32 for fp8 is not supported by assert_close
     torch.testing.assert_close(
         output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol
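
The test above adds an RMSE check alongside the loosened elementwise tolerances for the nvfp4 output, since per-element rtol/atol alone are a weak signal for such a coarsely quantized format. A standalone sketch of that criterion; the helper name is hypothetical:

```python
import torch

def rmse(actual: torch.Tensor, reference: torch.Tensor) -> float:
    """Root-mean-square error between an output and its reference, in float32."""
    return torch.sqrt(torch.mean((actual.float() - reference.float()) ** 2)).item()

# In the updated tests: rmse(output * o_scale, output_ref) must stay below 0.3.
```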

tests/test_trtllm_gen_decode.py

Lines changed: 8 additions & 5 deletions
@@ -222,7 +222,7 @@ def test_trtllm_batch_decode_fmha(
     else:
         o_scale = 1.0
     o_sf_scale = (
-        0.2 if o_dtype == "nvfp4" else None
+        300 if o_dtype == "nvfp4" else None
    )  # choose a value to make error smaller by testing.

     sm_scale = float(1.0 / (head_dim**0.5))
@@ -287,7 +287,7 @@ def test_trtllm_batch_decode_fmha(
     output_ref = wrapper.run(ref_q, ref_kv_cache)

     if q_dtype == "fp8" and o_dtype == "nvfp4":
-        rtol, atol = 5e-1, 1.1e0
+        rtol, atol = 3e-1, 1e0
     elif q_dtype == "fp8" and o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
     else:
@@ -303,10 +303,13 @@ def test_trtllm_batch_decode_fmha(
         torch.testing.assert_close(
             out_scale_factor.float().reshape(out_scale_factor_ref.shape),
             out_scale_factor_ref.float(),
-            rtol=rtol,
-            atol=atol,
+            rtol=2e-1,
+            atol=2e-1,
         )
-
+        rmse = torch.sqrt(
+            torch.mean((output.float() * o_scale - output_ref.float()) ** 2)
+        )
+        assert rmse.item() < 0.3
     # convert to float32 for fp8 is not supported by assert_close
     torch.testing.assert_close(
         output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol
