
[CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention#27321

Open
tianleiwu wants to merge 12 commits into main from tlwu/20260211/gqa_fp8_kv_cache

Conversation

tianleiwu (Contributor) commented Feb 12, 2026

Support FP8 (E4M3) KV Cache for Group Query Attention

Description

This PR adds FP8 E4M3 quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, complementing the existing INT8 and INT4 quantization paths. FP8 KV caches reduce memory bandwidth requirements during inference while maintaining higher numerical precision than INT8 for the same storage footprint.

Motivation

FP8 (E4M3) format preserves floating-point semantics with a wider dynamic range than INT8 (±448 vs ±128), making it well-suited for KV cache compression in LLM inference. This is especially beneficial on Ada Lovelace (SM89+) GPUs, which have native FP8 hardware support.

Changes

Build System

  • cmake/CMakeLists.txt: Added onnxruntime_USE_FP8_KV_CACHE build option (ON by default) with USE_FP8_KV_CACHE compiler flag. Also added build info strings for fp8-kv-cache, dump-tensor, and dump-node flags.
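A minimal sketch, assuming the CMake option simply defines USE_FP8_KV_CACHE for the compiler so FP8-specific code compiles out cleanly when the option is OFF (illustrative only, not the actual ORT source; Fp8KvCacheElem is a hypothetical alias):

```cuda
// Hypothetical illustration of the build-flag gating described above.
#if defined(USE_FP8_KV_CACHE)
#include <cuda_fp8.h>                    // provides __nv_fp8_e4m3 (CUDA 11.8+)
using Fp8KvCacheElem = __nv_fp8_e4m3;    // hypothetical alias for the FP8 cache element type
#endif
```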

Operator Schema

  • bert_defs.cc: Fixed shape inference for present_key/present_value when total_sequence_length input is provided and past_key has 0 length. The previous code could propagate a fixed dimension that later caused "Error merging shape info" warnings.
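The fix lives in ONNX shape inference; the standalone sketch below only illustrates the rule (InferPresentKeyShape and the -1 "symbolic" convention are hypothetical, not the bert_defs.cc API):

```cuda
#include <cstdint>
#include <vector>

// Hypothetical helper. Shapes are [batch, kv_num_heads, seq_len, head_size];
// -1 marks a symbolic/unknown dimension.
std::vector<int64_t> InferPresentKeyShape(std::vector<int64_t> past_key_shape,
                                          bool has_total_sequence_length_input) {
  if (past_key_shape.size() == 4 && past_key_shape[2] == 0 && has_total_sequence_length_input) {
    // total_sequence_length is a runtime tensor value, unknown during shape
    // inference, so keep the present sequence dimension symbolic instead of
    // propagating a fixed value that later triggers "Error merging shape info".
    past_key_shape[2] = -1;
  }
  return past_key_shape;
}
```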

Kernel Registration

  • group_query_attention.cc: Registered GroupQueryAttention<MLFloat16, Float8E4M3FN> and <BFloat16, Float8E4M3FN> kernel variants. Added FP8 XQA support gating (requires SM89+; a gating sketch follows this list) and correct XqaQuantType mapping.
  • cuda_contrib_kernels.cc: Added FP8 kernel class declarations and BuildKernelCreateInfo entries.
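The SM89+ gate is the standard compute-capability check; a hedged sketch (DeviceSupportsFp8Xqa is a hypothetical helper, and ORT presumably reads its cached device properties rather than calling cudaGetDeviceProperties per query):

```cuda
#include <cuda_runtime.h>

// Returns true if the device can use the FP8 XQA path
// (Ada Lovelace, compute capability 8.9, or newer). Illustrative only.
bool DeviceSupportsFp8Xqa(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  return prop.major > 8 || (prop.major == 8 && prop.minor >= 9);
}
```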

Core GQA Implementation

  • group_query_attention_impl.cu: Added template instantiations for <half, __nv_fp8_e4m3> and <__nv_bfloat16, __nv_fp8_e4m3>. Updated FlashAttentionAndQuantizeKV to dispatch to FP8 quantization kernels via constexpr type check. Wrapped INT4 instantiations in #ifdef USE_INT4_KV_CACHE.
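A sketch of the compile-time dispatch and explicit-instantiation pattern described above; QuantizeKVKernel and its scale handling are illustrative, not the actual FlashAttentionAndQuantizeKV code:

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cmath>
#include <type_traits>

// TCache selects the quantized element type at compile time, so the FP8 branch
// only exists in instantiations with __nv_fp8_e4m3 (no runtime bit_width switch).
template <typename T, typename TCache>
__global__ void QuantizeKVKernel(const T* kv, TCache* out, float inv_scale, size_t n) {
  size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i >= n) return;
  float v = static_cast<float>(kv[i]) * inv_scale;
  if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>) {
    out[i] = __nv_fp8_e4m3(v);               // E4M3 conversion (see the qdq sketch below for the explicit ±448 clamp)
  } else {
    out[i] = static_cast<TCache>(rintf(v));  // e.g., int8_t path
  }
}

// Explicit instantiations mirroring the <half, __nv_fp8_e4m3> and
// <__nv_bfloat16, __nv_fp8_e4m3> pairs mentioned above.
template __global__ void QuantizeKVKernel<half, __nv_fp8_e4m3>(
    const half*, __nv_fp8_e4m3*, float, size_t);
template __global__ void QuantizeKVKernel<__nv_bfloat16, __nv_fp8_e4m3>(
    const __nv_bfloat16*, __nv_fp8_e4m3*, float, size_t);
```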

Quantization / Dequantization Kernels

  • group_query_attention_qdq.cuh: Added FP8 E4M3 paths in both DequantizeKernel and QuantizeKernel using constexpr type dispatch on T_QUANT. FP8 values are clamped to ±448 before conversion.
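A hedged per-element sketch of the quantize/dequantize math described above (clamp to the E4M3 finite range before conversion); the real DequantizeKernel/QuantizeKernel also handle per-tensor vs. per-channel scales, vectorization, and the other cache types:

```cuda
#include <cuda_fp8.h>

// Illustrative helpers; scale conventions are assumptions, not the exact ORT code.
__device__ inline __nv_fp8_e4m3 QuantizeToFp8E4M3(float value, float scale) {
  float scaled = value / scale;
  // Clamp to the E4M3 finite range (±448) before conversion, as described above.
  scaled = fminf(fmaxf(scaled, -448.0f), 448.0f);
  return __nv_fp8_e4m3(scaled);
}

__device__ inline float DequantizeFromFp8E4M3(__nv_fp8_e4m3 q, float scale) {
  return static_cast<float>(q) * scale;
}
```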

Fused Unpack+RoPE+Append Kernel

  • group_query_attention_qkv.cuh: Refactored LaunchUnpackRoPEAppend to be templated on both T (query type) and U (cache type), replacing the runtime bit_width parameter with compile-time type-based dispatching. Added FP8 quantization path in the UnpackRoPEAppend kernel using __nv_fp8_e4m3 type. Fixed cache pointer arithmetic to use byte-level addressing.
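One reason for the byte-level addressing mentioned above: once the launcher is templated on the cache type U, the element width differs between half/bfloat16 caches (2 bytes) and __nv_fp8_e4m3/int8 caches (1 byte). A hypothetical helper illustrating the idea (PresentTokenPtr is not the real code):

```cuda
#include <cstdint>

// Computes the pointer to a token's slice of the present KV cache in bytes,
// so the same arithmetic works for 2-byte and 1-byte cache element types.
template <typename U>
__device__ U* PresentTokenPtr(void* cache_base, int64_t token_index, int64_t head_size) {
  const int64_t byte_offset = token_index * head_size * static_cast<int64_t>(sizeof(U));
  return reinterpret_cast<U*>(static_cast<uint8_t*>(cache_base) + byte_offset);
}
```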

XQA Kernel Integration

  • mha.h: Changed InputElem from half to __nv_fp8_e4m3 when CACHE_ELEM_ENUM == 2 (FP8); see the sketch after this list.
  • xqa_loader_fp16_impl.cuh / xqa_loader_bf16_impl.cuh: Added extern declarations and dispatch logic for FP8 kernels (LaunchXQAFp8Kernel / LaunchXQAFp8KernelBF16).
  • xqa_loader_fp16_fp8_impl.cuh / xqa_loader_bf16_fp8_impl.cuh [NEW]: FP8 XQA kernel instantiation files with group sizes 4, 8, 16, 32.
  • xqa_loader_{fp16,bf16}_fp8_{64,128,256}.cu [NEW]: Per-head-size compilation units for FP8 XQA kernels.
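A sketch of the compile-time element-type selection that the mha.h change describes; std::conditional_t is an assumed mechanism, and CACHE_ELEM_ENUM == 2 meaning FP8 is taken from the description above:

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <type_traits>

// Assumed convention: CACHE_ELEM_ENUM == 2 selects the FP8 (E4M3) cache,
// otherwise the kernel reads half elements from the KV cache.
#ifndef CACHE_ELEM_ENUM
#define CACHE_ELEM_ENUM 2
#endif

using InputElem = std::conditional_t<CACHE_ELEM_ENUM == 2, __nv_fp8_e4m3, half>;
static_assert(sizeof(InputElem) == (CACHE_ELEM_ENUM == 2 ? 1 : 2),
              "unexpected cache element width");
```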

Python Tooling

  • io_binding_helper.py: Extended TypeHelper with comprehensive data type coverage: added FP8 (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz), int4/uint4, double, int16/uint16, uint32/uint64, complex64/complex128, and string mappings across all conversion methods.

Tests & Benchmarks

  • test_gqa.py: Added test_gqa_fp8_kv_cache, test_gqa_fp8_prompt, and test_gqa_fp8_fallback_unsupported_head_size test cases. Extended quantized test matrix to include FP8. Added FP8-specific tolerance values.
  • gqa_test_helper.py: Added FP8 cache type handling in parity_check_gqa_past and parity_check_gqa_prompt for proper tensor creation and dequantization comparison.
  • benchmark_gqa.py: Added FP8 benchmark support with --fp8 flag.

Testing

  • Unit tests: test_gqa_fp8_kv_cache, test_gqa_fp8_prompt, test_gqa_fp8_fallback_unsupported_head_size
  • Quantized test matrix expanded with FP8 variants (PER_TENSOR, PER_CHANNEL, shared/separate scales)
  • Benchmark: benchmark_gqa.py --fp8

Requirements

  • CUDA GPU with SM89+ (Ada Lovelace / RTX 4000 series or newer) for FP8 support
  • Build with onnxruntime_USE_FP8_KV_CACHE=ON (default)


Copilot AI left a comment


Pull request overview

This PR adds FP8 (E4M3) quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, enabling reduced memory bandwidth for LLM inference on SM89+ GPUs (Ada Lovelace and newer). FP8 provides better numerical precision than INT8 while maintaining the same 8-bit storage footprint, with a dynamic range of ±448 vs ±128.

Changes:

  • Added FP8 KV cache support with kernel registration, quantization/dequantization logic, and XQA kernel integration for FP16/BF16 query types
  • Extended Python test framework and benchmarks with FP8 test cases and type mappings
  • Fixed shape inference for present_key/present_value when past_key has zero length with total_sequence_length input

Reviewed changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

Summary of changes per file:

| File | Description |
| --- | --- |
| cmake/CMakeLists.txt | Added USE_FP8_KV_CACHE build option (ON by default) with compiler flags and build info strings |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Fixed shape inference edge case for empty past_key with total_sequence_length |
| onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | Added FP8 kernel class declarations and build info entries |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Registered FP8 kernel variants with SM89+ hardware check for XQA support |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Added FP8 template instantiations and dispatch logic for quantization kernels |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh | Implemented FP8 quantization/dequantization kernels with ±448 clamping |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh | Refactored fused Unpack+RoPE+Append kernel to support FP8 cache type via template dispatch |
| onnxruntime/contrib_ops/cuda/bert/xqa/*.cuh, *.cu | Added FP8 XQA kernel implementations for FP16/BF16 queries with head sizes 64/128/256 |
| onnxruntime/python/tools/transformers/io_binding_helper.py | Extended TypeHelper with FP8, int4/uint4, and additional numeric type mappings |
| onnxruntime/test/python/transformers/test_gqa.py | Added FP8 test cases with proper tolerance values and detection logic |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Implemented FP8 quantization/dequantization helpers and dtype mappings |
| onnxruntime/test/python/transformers/benchmark_gqa.py | Extended benchmark with FP8 support via --fp8 flag |
| onnxruntime/core/providers/cuda/llm/attention.cc | Updated GQA data structure instantiation to use templated cache type |


tianleiwu marked this pull request as ready for review on February 12, 2026 07:17
github-actions bot left a comment

You can commit the suggested changes from lintrunner.
