[CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention #27321
Pull request overview
This PR adds FP8 (E4M3) quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, reducing memory bandwidth for LLM inference on SM89+ GPUs (Ada Lovelace and newer). FP8 provides better numerical precision than INT8 at the same 8-bit storage footprint, with a wider dynamic range (±448 vs ±128).
Changes:
- Added FP8 KV cache support with kernel registration, quantization/dequantization logic, and XQA kernel integration for FP16/BF16 query types
- Extended Python test framework and benchmarks with FP8 test cases and type mappings
- Fixed shape inference for present_key/present_value when past_key has zero length with total_sequence_length input
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| cmake/CMakeLists.txt | Added USE_FP8_KV_CACHE build option (ON by default) with compiler flags and build info strings |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Fixed shape inference edge case for empty past_key with total_sequence_length |
| onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | Added FP8 kernel class declarations and build info entries |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Registered FP8 kernel variants with SM89+ hardware check for XQA support |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Added FP8 template instantiations and dispatch logic for quantization kernels |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh | Implemented FP8 quantization/dequantization kernels with ±448 clamping |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh | Refactored fused Unpack+RoPE+Append kernel to support FP8 cache type via template dispatch |
| onnxruntime/contrib_ops/cuda/bert/xqa/*.cuh, *.cu | Added FP8 XQA kernel implementations for FP16/BF16 queries with head sizes 64/128/256 |
| onnxruntime/python/tools/transformers/io_binding_helper.py | Extended TypeHelper with FP8, int4/uint4, and additional numeric type mappings |
| onnxruntime/test/python/transformers/test_gqa.py | Added FP8 test cases with proper tolerance values and detection logic |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Implemented FP8 quantization/dequantization helpers and dtype mappings |
| onnxruntime/test/python/transformers/benchmark_gqa.py | Extended benchmark with FP8 support via --fp8 flag |
| onnxruntime/core/providers/cuda/llm/attention.cc | Updated GQA data structure instantiation to use templated cache type |
Support FP8 (E4M3) KV Cache for Group Query Attention
Description
This PR adds FP8 E4M3 quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, complementing the existing INT8 and INT4 quantization paths. FP8 KV caches reduce memory bandwidth requirements during inference while maintaining higher numerical precision than INT8 for the same storage footprint.
Motivation
The FP8 (E4M3) format preserves floating-point semantics with a wider dynamic range than INT8 (±448 vs ±128), making it well suited for KV cache compression in LLM inference. This is especially beneficial on Ada Lovelace (SM89+) and newer GPUs, which have native FP8 hardware support.
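As a quick illustration of that E4M3 range, the sketch below (standalone, not part of this PR) round-trips a few values through CUDA's `__nv_fp8_e4m3` type from `cuda_fp8.h`; magnitudes above 448 saturate to the maximum finite value. Kernel and variable names are illustrative only.

```cuda
// Illustrative only: round-trip a few floats through FP8 E4M3 to show its
// dynamic range (max finite magnitude 448). Requires CUDA 11.8+ for cuda_fp8.h.
#include <cuda_fp8.h>
#include <cstdio>

__global__ void Fp8RoundTrip(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __nv_fp8_e4m3 q(in[i]);          // float -> FP8 E4M3 (saturating conversion)
    out[i] = static_cast<float>(q);  // FP8 E4M3 -> float
  }
}

int main() {
  const int n = 4;
  float h_in[n] = {0.1f, 96.0f, 448.0f, 1000.0f};
  float h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  Fp8RoundTrip<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("%g -> %g\n", h_in[i], h_out[i]);  // 1000 saturates to 448
  }
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```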
Changes
Build System
- Added the `onnxruntime_USE_FP8_KV_CACHE` build option (ON by default) with the `USE_FP8_KV_CACHE` compiler flag. Also added build info strings for the `fp8-kv-cache`, `dump-tensor`, and `dump-node` flags.
Operator Schema
- Fixed shape inference for `present_key`/`present_value` when the `total_sequence_length` input is provided and `past_key` has 0 length. The previous code could propagate a fixed dimension that later caused "Error merging shape info" warnings.
Kernel Registration
- Registered `GroupQueryAttention<MLFloat16, Float8E4M3FN>` and `<BFloat16, Float8E4M3FN>` kernel variants. Added FP8 XQA support gating (requires SM89+, see the sketch below) and the correct `XqaQuantType` mapping.
- Added the corresponding `BuildKernelCreateInfo` entries.
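The SM89+ gate amounts to a compute-capability check before selecting the FP8 XQA path. A minimal sketch, assuming an illustrative helper name (`IsFp8XqaSupported` is not the PR's actual code):

```cuda
#include <cuda_runtime.h>

// Illustrative gate: FP8 XQA requires compute capability 8.9 (Ada Lovelace) or newer.
bool IsFp8XqaSupported(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  return prop.major > 8 || (prop.major == 8 && prop.minor >= 9);
}
```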
Core GQA Implementation
- Added template instantiations for `<half, __nv_fp8_e4m3>` and `<__nv_bfloat16, __nv_fp8_e4m3>`.
- Updated `FlashAttentionAndQuantizeKV` to dispatch to the FP8 quantization kernels via a `constexpr` type check (sketched below).
- Wrapped the INT4 instantiations in `#ifdef USE_INT4_KV_CACHE`.
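The `constexpr` dispatch on the cache element type can be pictured as follows. This is a hedged sketch: the kernel and launcher names (`QuantizeKvFp8`, `QuantizeKvInt8`, `LaunchQuantizeKv`) are hypothetical stand-ins, not the PR's actual `FlashAttentionAndQuantizeKV` signature.

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <type_traits>

// Hypothetical FP8 path: the __nv_fp8_e4m3 conversion saturates to the finite
// E4M3 range (the explicit +-448 clamp is shown in the next sketch).
template <typename T>
__global__ void QuantizeKvFp8(const T* kv, __nv_fp8_e4m3* cache, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) cache[i] = __nv_fp8_e4m3(static_cast<float>(kv[i]) * scale);
}

// Hypothetical INT8 path: round to nearest and clamp to [-128, 127].
template <typename T>
__global__ void QuantizeKvInt8(const T* kv, int8_t* cache, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = rintf(static_cast<float>(kv[i]) * scale);
    cache[i] = static_cast<int8_t>(fmaxf(-128.0f, fminf(127.0f, v)));
  }
}

// Compile-time dispatch on the cache element type T_CACHE
// (this sketch assumes the non-FP8 cache type is int8_t).
template <typename T, typename T_CACHE>
void LaunchQuantizeKv(const T* kv, T_CACHE* cache, float scale, int n, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  if constexpr (std::is_same_v<T_CACHE, __nv_fp8_e4m3>) {
    QuantizeKvFp8<T><<<blocks, threads, 0, stream>>>(kv, cache, scale, n);
  } else {
    QuantizeKvInt8<T><<<blocks, threads, 0, stream>>>(kv, cache, scale, n);
  }
}
```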
Quantization / Dequantization Kernels
- Extended `DequantizeKernel` and `QuantizeKernel` with `constexpr` type dispatch on `T_QUANT`. FP8 values are clamped to ±448 before conversion (see the element-level sketch below).
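At the element level, the quantize/dequantize pair with the ±448 clamp might look like the sketch below. Only the `T_QUANT` dispatch pattern and the clamp follow the description above; the function names and per-tensor-scale assumption are illustrative.

```cuda
#include <cuda_fp8.h>
#include <cstdint>
#include <type_traits>

// Quantize one value to the cache type T_QUANT (int8_t or __nv_fp8_e4m3).
template <typename T_QUANT>
__device__ __forceinline__ T_QUANT QuantizeElement(float value, float inv_scale) {
  float v = value * inv_scale;
  if constexpr (std::is_same_v<T_QUANT, __nv_fp8_e4m3>) {
    // Clamp to the E4M3 finite range (+-448) before conversion.
    v = fmaxf(-448.0f, fminf(448.0f, v));
    return __nv_fp8_e4m3(v);
  } else {
    // INT8 path: round to nearest and clamp to the int8 range.
    v = fmaxf(-128.0f, fminf(127.0f, rintf(v)));
    return static_cast<int8_t>(v);
  }
}

// Dequantize one cached value back to float.
template <typename T_QUANT>
__device__ __forceinline__ float DequantizeElement(T_QUANT q, float scale) {
  if constexpr (std::is_same_v<T_QUANT, __nv_fp8_e4m3>) {
    return static_cast<float>(q) * scale;                      // FP8 -> float, then rescale
  } else {
    return static_cast<float>(static_cast<int>(q)) * scale;    // int8 -> float, then rescale
  }
}
```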
Fused Unpack+RoPE+Append Kernel
- Refactored `LaunchUnpackRoPEAppend` to be templated on both `T` (query type) and `U` (cache type), replacing the runtime `bit_width` parameter with compile-time, type-based dispatch.
- Added an FP8 quantization path in the `UnpackRoPEAppend` kernel using the `__nv_fp8_e4m3` type.
- Fixed cache pointer arithmetic to use byte-level addressing (illustrated below).
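Byte-level addressing matters because the cache element type `U` (1 byte for FP8) no longer matches the query type `T` (2 bytes for fp16/bf16), so offsets into the cache must be computed in units of `U`, not `T`. A minimal sketch, assuming a [batch, kv_head, token, head_size] cache layout and illustrative names:

```cuda
#include <cuda_fp8.h>
#include <cstddef>

// Illustrative offset computation for a KV cache whose element type U
// (e.g. __nv_fp8_e4m3, 1 byte) differs from the query type T (e.g. half, 2 bytes).
// The function name and layout assumption are not taken from the PR.
template <typename U>
__host__ __device__ U* CacheTokenPtr(void* cache_base,
                                     size_t batch, size_t head, size_t token,
                                     size_t max_seq_len, size_t num_kv_heads,
                                     size_t head_size) {
  // Element offset in units of U for a [batch, num_kv_heads, max_seq_len, head_size] layout.
  size_t elem_offset = ((batch * num_kv_heads + head) * max_seq_len + token) * head_size;
  // Scale by sizeof(U) so FP8 (1 byte) and FP16 (2 byte) caches are both
  // addressed correctly from an untyped base pointer.
  return reinterpret_cast<U*>(static_cast<char*>(cache_base) + elem_offset * sizeof(U));
}
```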
XQA Kernel Integration
- Changed `InputElem` from `half` to `__nv_fp8_e4m3` when `CACHE_ELEM_ENUM == 2` (FP8); see the compile-time selection sketch below.
- Added FP8 launch entry points (`LaunchXQAFp8Kernel` / `LaunchXQAFp8KernelBF16`).
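The `CACHE_ELEM_ENUM`-based element-type selection can be expressed at compile time with `std::conditional_t`. A minimal sketch: the value 2 for FP8 follows the description above, while the assumption that any other value denotes an fp16 cache, and the alias-template form, are illustrative only.

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <type_traits>

// Illustrative compile-time selection of the XQA cache element type:
// 2 -> FP8 E4M3 cache (per the PR description); other values assumed to mean fp16 here.
template <int CACHE_ELEM_ENUM>
using InputElem = std::conditional_t<CACHE_ELEM_ENUM == 2, __nv_fp8_e4m3, half>;

static_assert(sizeof(InputElem<2>) == 1, "FP8 cache elements are 1 byte");
static_assert(sizeof(InputElem<0>) == 2, "fp16 cache elements are 2 bytes");
```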
Python Tooling
- Extended `TypeHelper` with comprehensive data type coverage: added FP8 (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz), int4/uint4, double, int16/uint16, uint32/uint64, complex64/complex128, and string mappings across all conversion methods.
Tests & Benchmarks
- Added `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, and `test_gqa_fp8_fallback_unsupported_head_size` test cases. Extended the quantized test matrix to include FP8 and added FP8-specific tolerance values.
- Updated `parity_check_gqa_past` and `parity_check_gqa_prompt` for proper tensor creation and dequantization comparison.
- Extended the benchmark with FP8 support via a `--fp8` flag.
Testing
- `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, `test_gqa_fp8_fallback_unsupported_head_size`
- `benchmark_gqa.py --fp8`
Requirements
- `onnxruntime_USE_FP8_KV_CACHE=ON` (default)