Commit 19c9efc
[CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention (#27321)
# Support FP8 (E4M3) KV Cache for Group Query Attention
## Description
This PR adds FP8 E4M3 quantized KV cache support for the Group Query
Attention (GQA) operator on CUDA, complementing the existing INT8 and
INT4 quantization paths. FP8 KV caches reduce memory bandwidth
requirements during inference while maintaining higher numerical
precision than INT8 for the same storage footprint.
### Motivation
The FP8 (E4M3) format preserves floating-point semantics with a wider
dynamic range than INT8 (±448 vs ±127), making it well suited for KV
cache compression in LLM inference. This is especially beneficial on Ada
Lovelace (SM89+) GPUs, which have native FP8 hardware support.
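To make the dynamic-range comparison concrete, here is a standalone sketch (not ONNX Runtime code) that enumerates the finite non-negative values of the E4M3FN encoding; the parameters (4 exponent bits with bias 7, 3 mantissa bits, a single NaN encoding in place of infinities) follow the standard E4M3FN definition.

```python
def e4m3fn_values():
    """Enumerate all non-negative finite FP8 E4M3FN values.

    Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    E4M3FN has no infinities; only exponent=15, mantissa=7 encodes NaN,
    so the largest finite value is 2^8 * 1.75 = 448.
    """
    vals = set()
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:
                continue  # the single NaN encoding
            if e == 0:
                vals.add((m / 8) * 2.0 ** -6)           # subnormals
            else:
                vals.add((1 + m / 8) * 2.0 ** (e - 7))  # normals
    return sorted(vals)

vals = e4m3fn_values()
print(max(vals))  # 448.0
```

The maximum finite value of 448 (versus 127 for signed INT8) is the extra headroom that makes E4M3 attractive for KV cache values at the same 8-bit storage cost.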
## Changes
### Build System
- **cmake/CMakeLists.txt**: Added `onnxruntime_USE_FP8_KV_CACHE` build
option (ON by default) with `USE_FP8_KV_CACHE` compiler flag. Also added
build info strings for `fp8-kv-cache`, `dump-tensor`, and `dump-node`
flags.
### Operator Schema
- **bert_defs.cc**: Fixed shape inference for
`present_key`/`present_value` when `total_sequence_length` input is
provided and past_key has 0 length. The previous code could propagate a
fixed dimension that later caused "Error merging shape info" warnings.
### Kernel Registration
- **group_query_attention.cc**: Registered
`GroupQueryAttention<MLFloat16, Float8E4M3FN>` and `<BFloat16,
Float8E4M3FN>` kernel variants. Added FP8 XQA support gating (requires
SM89+) and correct `XqaQuantType` mapping.
- **cuda_contrib_kernels.cc**: Added FP8 kernel class declarations and
`BuildKernelCreateInfo` entries.
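As an illustrative sketch of the gating described above (the names `xqa_supports` and the type strings are hypothetical, not ORT's actual enums or API), FP8 XQA dispatch amounts to a capability check on top of the cache-type mapping:

```python
# Hypothetical cache-dtype -> quant-type mapping; illustrative names only.
QUANT_TYPE_BY_CACHE_DTYPE = {
    "int8": "kInt8",
    "int4": "kInt4",
    "float8e4m3fn": "kFp8E4M3",
}

def xqa_supports(cache_dtype: str, sm: int) -> bool:
    """FP8 KV cache needs native FP8 hardware (SM 8.9 / Ada Lovelace or newer)."""
    if cache_dtype == "float8e4m3fn":
        return sm >= 89
    return cache_dtype in QUANT_TYPE_BY_CACHE_DTYPE

print(xqa_supports("float8e4m3fn", 89))  # True
print(xqa_supports("float8e4m3fn", 86))  # False: Ampere lacks native FP8
```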
### Core GQA Implementation
- **group_query_attention_impl.cu**: Added template instantiations for
`<half, __nv_fp8_e4m3>` and `<__nv_bfloat16, __nv_fp8_e4m3>`. Updated
`FlashAttentionAndQuantizeKV` to dispatch to FP8 quantization kernels
via `constexpr` type check. Wrapped INT4 instantiations in `#ifdef
USE_INT4_KV_CACHE`.
### Quantization / Dequantization Kernels
- **group_query_attention_qdq.cuh**: Added FP8 E4M3 paths in both
`DequantizeKernel` and `QuantizeKernel` using `constexpr` type dispatch
on `T_QUANT`. FP8 values are clamped to ±448 before conversion.
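The quantize path's clamp-before-convert step can be sketched in plain Python as follows. This mirrors the PR's description (saturate to ±448, the E4M3FN finite maximum, before the format conversion); the per-tensor `scale` and list-based layout are simplifications for illustration, and true FP8 mantissa rounding is not modeled.

```python
E4M3_MAX = 448.0  # largest finite E4M3FN value

def quantize_fp8(x, scale):
    """Divide by scale, then saturate to the representable E4M3FN range."""
    return [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in x]

def dequantize_fp8(q, scale):
    """Inverse of the scaling step."""
    return [v * scale for v in q]

x = [0.5, -3.0, 1000.0]
q = quantize_fp8(x, scale=1.0)
print(q)  # [0.5, -3.0, 448.0] -- the out-of-range value saturates
```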
### Fused Unpack+RoPE+Append Kernel
- **group_query_attention_qkv.cuh**: Refactored `LaunchUnpackRoPEAppend`
to be templated on both `T` (query type) and `U` (cache type), replacing
the runtime `bit_width` parameter with compile-time type-based
dispatching. Added FP8 quantization path in the `UnpackRoPEAppend`
kernel using `__nv_fp8_e4m3` type. Fixed cache pointer arithmetic to use
byte-level addressing.
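For context on what the RoPE stage of this fused kernel computes, here is a minimal sketch of rotary position embedding applied to one element pair: each pair is rotated by a position- and dimension-dependent angle. The interleaved pairing and the base of 10000 are common conventions assumed here for illustration, not details taken from this PR.

```python
import math

def rope_rotate(pair, position, dim_index, head_size, base=10000.0):
    """Rotate one (x0, x1) pair by theta = position * base^(-2*dim_index/head_size)."""
    x0, x1 = pair
    theta = position * base ** (-2.0 * dim_index / head_size)
    c, s = math.cos(theta), math.sin(theta)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

# Position 0 gives theta = 0, so the rotation is the identity:
print(rope_rotate((1.0, 2.0), position=0, dim_index=0, head_size=64))  # (1.0, 2.0)
```

Because the rotation is norm-preserving, applying it before quantization does not change the magnitude range the FP8 cache must represent.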
### XQA Kernel Integration
- **mha.h**: Changed `InputElem` from `half` to `__nv_fp8_e4m3` when
`CACHE_ELEM_ENUM == 2` (FP8).
- **xqa_loader_fp16_impl.cuh / xqa_loader_bf16_impl.cuh**: Added extern
declarations and dispatch logic for FP8 kernels (`LaunchXQAFp8Kernel` /
`LaunchXQAFp8KernelBF16`).
- **xqa_loader_fp16_fp8_impl.cuh / xqa_loader_bf16_fp8_impl.cuh** [NEW]:
FP8 XQA kernel instantiation files with group sizes 4, 8, 16, 32.
- **xqa_loader_{fp16,bf16}_fp8_{64,128,256}.cu** [NEW]: Per-head-size
compilation units for FP8 XQA kernels.
### Python Tooling
- **io_binding_helper.py**: Extended `TypeHelper` with comprehensive
data type coverage: added FP8 (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz),
int4/uint4, double, int16/uint16, uint32/uint64, complex64/complex128,
and string mappings across all conversion methods.
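In the spirit of that `TypeHelper` extension, a mapping from ONNX tensor-type strings to numpy-style dtype names might look like the sketch below. The dictionary contents here are illustrative assumptions (numpy has no native FP8 or packed-int4 dtype, so raw integer storage is a plausible stand-in); consult `io_binding_helper.py` for the actual mappings.

```python
# Hypothetical subset of an ORT-type -> numpy-dtype-name table; illustrative only.
ORT_TO_NUMPY_NAME = {
    "tensor(float)": "float32",
    "tensor(float16)": "float16",
    "tensor(double)": "float64",
    "tensor(float8e4m3fn)": "uint8",  # no native numpy FP8; raw bytes
    "tensor(bfloat16)": "uint16",     # stored as raw 16-bit words
    "tensor(int4)": "uint8",          # two nibbles packed per byte
}

def numpy_name(ort_type: str) -> str:
    if ort_type not in ORT_TO_NUMPY_NAME:
        raise ValueError(f"unsupported type: {ort_type}")
    return ORT_TO_NUMPY_NAME[ort_type]

print(numpy_name("tensor(float8e4m3fn)"))  # uint8
```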
### Tests & Benchmarks
- **test_gqa.py**: Added `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`,
and `test_gqa_fp8_fallback_unsupported_head_size` test cases. Extended
quantized test matrix to include FP8. Added FP8-specific tolerance
values.
- **gqa_test_helper.py**: Added FP8 cache type handling in
`parity_check_gqa_past` and `parity_check_gqa_prompt` for proper tensor
creation and dequantization comparison.
- **benchmark_gqa.py**: Added FP8 benchmark support with `--fp8` flag.
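The parity-check pattern those helpers implement can be sketched as: quantize the reference KV tensor, dequantize it back, and compare against the FP8 path within an FP8-appropriate tolerance. The helper names and the tolerance value below are illustrative, not taken from `gqa_test_helper.py`.

```python
E4M3_MAX = 448.0

def fake_quant(x, scale):
    """Quantize-dequantize round trip (saturating to the E4M3FN range)."""
    return [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) * scale for v in x]

def allclose(a, b, atol=0.25):
    """Elementwise comparison with an absolute tolerance loose enough for FP8."""
    return all(abs(x - y) <= atol for x, y in zip(a, b))

ref = [0.1, -2.5, 7.0]
out = fake_quant(ref, scale=0.05)
print(allclose(ref, out))  # True: in-range values survive the round trip
```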
## Testing
- Unit tests: `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`,
`test_gqa_fp8_fallback_unsupported_head_size`
- Quantized test matrix expanded with FP8 variants (PER_TENSOR,
PER_CHANNEL, shared/separate scales)
- Benchmark: `benchmark_gqa.py --fp8`
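To clarify the PER_TENSOR / PER_CHANNEL axis of the test matrix, the two scale modes differ in granularity: one scale for the whole tensor versus one per channel. A minimal sketch, with hypothetical function names and a row-major list-of-rows layout assumed for illustration:

```python
E4M3_MAX = 448.0

def per_tensor_scale(x):
    """One scale covering the whole tensor's absolute maximum."""
    return max(abs(v) for row in x for v in row) / E4M3_MAX

def per_channel_scales(x):
    """One scale per channel (column), tracking each channel's own range."""
    ncols = len(x[0])
    return [max(abs(row[c]) for row in x) / E4M3_MAX for c in range(ncols)]

x = [[448.0, 1.0], [-224.0, 0.5]]
print(per_tensor_scale(x))    # 1.0: the large channel dictates the scale
print(per_channel_scales(x))  # the small channel keeps a much finer scale
```

Per-channel scales preserve precision for small-magnitude channels that a single per-tensor scale would crush, which is why the test matrix exercises both.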
## Requirements
- CUDA GPU with SM89+ (Ada Lovelace / RTX 4000 series or newer) for FP8
support
- Build with `onnxruntime_USE_FP8_KV_CACHE=ON` (default)

Parent: b5f246b · Commit: 19c9efc
26 files changed (+1140, −422 lines) across cmake, docs, onnxruntime (contrib_ops/cuda/bert, contrib_ops/cuda/xqa, core/graph/contrib_ops, core/providers/cuda/llm), python/tools/transformers, and test/python/transformers.