
Commit 19c9efc

[CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention (#27321)
# Support FP8 (E4M3) KV Cache for Group Query Attention

## Description

This PR adds FP8 E4M3 quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, complementing the existing INT8 and INT4 quantization paths. FP8 KV caches reduce memory bandwidth requirements during inference while maintaining higher numerical precision than INT8 for the same storage footprint.

### Motivation

The FP8 (E4M3) format preserves floating-point semantics with a wider dynamic range than INT8 (±448 vs. ±128), making it well suited for KV cache compression in LLM inference. This is especially beneficial on Ada Lovelace (SM89+) GPUs, which have native FP8 hardware support.

## Changes

### Build System

- **cmake/CMakeLists.txt**: Added the `onnxruntime_USE_FP8_KV_CACHE` build option (ON by default) with the `USE_FP8_KV_CACHE` compiler flag. Also added build info strings for the `fp8-kv-cache`, `dump-tensor`, and `dump-node` flags.

### Operator Schema

- **bert_defs.cc**: Fixed shape inference for `present_key`/`present_value` when the `total_sequence_length` input is provided and `past_key` has zero length. The previous code could propagate a fixed dimension that later caused "Error merging shape info" warnings.

### Kernel Registration

- **group_query_attention.cc**: Registered `GroupQueryAttention<MLFloat16, Float8E4M3FN>` and `<BFloat16, Float8E4M3FN>` kernel variants. Added FP8 XQA support gating (requires SM89+) and the correct `XqaQuantType` mapping.
- **cuda_contrib_kernels.cc**: Added FP8 kernel class declarations and `BuildKernelCreateInfo` entries.

### Core GQA Implementation

- **group_query_attention_impl.cu**: Added template instantiations for `<half, __nv_fp8_e4m3>` and `<__nv_bfloat16, __nv_fp8_e4m3>`. Updated `FlashAttentionAndQuantizeKV` to dispatch to FP8 quantization kernels via a `constexpr` type check. Wrapped the INT4 instantiations in `#ifdef USE_INT4_KV_CACHE`.

### Quantization / Dequantization Kernels

- **group_query_attention_qdq.cuh**: Added FP8 E4M3 paths in both `DequantizeKernel` and `QuantizeKernel` using `constexpr` type dispatch on `T_QUANT`. FP8 values are clamped to ±448 before conversion.

### Fused Unpack+RoPE+Append Kernel

- **group_query_attention_qkv.cuh**: Refactored `LaunchUnpackRoPEAppend` to be templated on both `T` (query type) and `U` (cache type), replacing the runtime `bit_width` parameter with compile-time type-based dispatch. Added an FP8 quantization path in the `UnpackRoPEAppend` kernel using the `__nv_fp8_e4m3` type. Fixed cache pointer arithmetic to use byte-level addressing.

### XQA Kernel Integration

- **mha.h**: Changed `InputElem` from `half` to `__nv_fp8_e4m3` when `CACHE_ELEM_ENUM == 2` (FP8).
- **xqa_loader_fp16_impl.cuh / xqa_loader_bf16_impl.cuh**: Added extern declarations and dispatch logic for the FP8 kernels (`LaunchXQAFp8Kernel` / `LaunchXQAFp8KernelBF16`).
- **xqa_loader_fp16_fp8_impl.cuh / xqa_loader_bf16_fp8_impl.cuh** [NEW]: FP8 XQA kernel instantiation files with group sizes 4, 8, 16, and 32.
- **xqa_loader_{fp16,bf16}_fp8_{64,128,256}.cu** [NEW]: Per-head-size compilation units for the FP8 XQA kernels.

### Python Tooling

- **io_binding_helper.py**: Extended `TypeHelper` with comprehensive data type coverage: added FP8 (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz), int4/uint4, double, int16/uint16, uint32/uint64, complex64/complex128, and string mappings across all conversion methods.

### Tests & Benchmarks

- **test_gqa.py**: Added the `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, and `test_gqa_fp8_fallback_unsupported_head_size` test cases. Extended the quantized test matrix to include FP8 and added FP8-specific tolerance values.
- **gqa_test_helper.py**: Added FP8 cache type handling in `parity_check_gqa_past` and `parity_check_gqa_prompt` for proper tensor creation and dequantization comparison.
- **benchmark_gqa.py**: Added FP8 benchmark support with the `--fp8` flag.

## Testing

- Unit tests: `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, `test_gqa_fp8_fallback_unsupported_head_size`
- Quantized test matrix expanded with FP8 variants (PER_TENSOR, PER_CHANNEL, shared/separate scales)
- Benchmark: `benchmark_gqa.py --fp8`

## Requirements

- CUDA GPU with SM89+ (Ada Lovelace / RTX 4000 series or newer) for FP8 support
- Build with `onnxruntime_USE_FP8_KV_CACHE=ON` (the default)
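The per-tensor FP8 path described above scales values by a shared scale and clamps them to the ±448 E4M3 range before conversion. A minimal sketch in plain Python, assuming the simplest case; function names are illustrative, and the real CUDA kernels additionally round the clamped values to the nearest representable `__nv_fp8_e4m3` number:

```python
# Illustrative per-tensor FP8 (E4M3) quantization sketch -- not the actual
# ORT kernels, which do the clamp and E4M3 conversion on the GPU.
FP8_E4M3_MAX = 448.0  # largest finite E4M3 value


def quantize_per_tensor(values, scale):
    """Scale down and clamp into the representable E4M3 range [-448, 448]."""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in values]


def dequantize_per_tensor(qvalues, scale):
    """Multiply by the shared per-tensor scale to recover approximate values."""
    return [q * scale for q in qvalues]


vals = [512.0, -1000.0, 3.5]
q = quantize_per_tensor(vals, scale=2.0)
print(q)  # [256.0, -448.0, 1.75] -- the -1000.0 input saturates after scaling
```

Note that values whose scaled magnitude exceeds 448 saturate rather than wrap, which is why choosing a scale that covers the observed KV dynamic range matters for accuracy.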
1 parent b5f246b commit 19c9efc

26 files changed (+1140, −422 lines)

cmake/CMakeLists.txt

Lines changed: 15 additions & 0 deletions

```diff
@@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do
 cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
 option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
+option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON)
 option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)

 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA)
     message( STATUS "Enable int4 kv cache for CUDA EP")
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
   endif()
+
+  if (onnxruntime_USE_FP8_KV_CACHE)
+    message( STATUS "Enable fp8 kv cache for CUDA EP")
+    list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1)
+  endif()
 endif()

 if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))
@@ -1442,6 +1448,15 @@ if (Git_FOUND)
   if (onnxruntime_USE_INT4_KV_CACHE)
     string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
   endif()
+  if (onnxruntime_USE_FP8_KV_CACHE)
+    string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
+  endif()
+  if (onnxruntime_DUMP_TENSOR)
+    string(APPEND ORT_BUILD_INFO "dump-tensor=1, ")
+  endif()
+  if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
+    string(APPEND ORT_BUILD_INFO "dump-node=1, ")
+  endif()
 endif()
 string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
 configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
```
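The effect of the CMake changes above can be paraphrased in a few lines of Python: each enabled cache option appends one `-D` define to the provider flags, and FP8 is the only one that defaults to ON. The function below is a hypothetical model of that logic, not part of the build system:

```python
def provider_flags(use_int4_kv_cache=False, use_fp8_kv_cache=True):
    """Toy model of the CMake logic: each enabled option appends a -D define.
    Defaults mirror the diff: INT4 is OFF by default, FP8 is ON by default."""
    flags = []
    if use_int4_kv_cache:
        flags.append("-DUSE_INT4_KV_CACHE=1")
    if use_fp8_kv_cache:
        flags.append("-DUSE_FP8_KV_CACHE=1")
    return flags


print(provider_flags())  # ['-DUSE_FP8_KV_CACHE=1']
```

Because the option defaults to ON, existing CUDA builds pick up the FP8 kernels automatically; passing `onnxruntime_USE_FP8_KV_CACHE=OFF` removes the define and compiles them out.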

docs/OperatorKernels.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -1003,7 +1003,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
```
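The `tensor(float8e4m3fn)` type added to the `T_CACHE` constraint uses 1 sign bit, 4 exponent bits, and 3 mantissa bits with bias 7 and no infinities; the exponent field `1111` remains finite, and only mantissa `111` at that exponent encodes NaN. The ±448 clamp limit used in the quantization kernels falls out of this directly:

```python
def e4m3fn_max():
    """Largest finite E4M3FN value.

    In the 'fn' (finite, NaN-only) variant, exponent field 1111 still encodes
    normal numbers; only mantissa 111 with that exponent is NaN, so the
    largest usable mantissa is 110.
    """
    bias = 7
    exponent = 0b1111 - bias   # 15 - 7 = 8
    mantissa = 1 + 0b110 / 8   # 1.75
    return mantissa * 2 ** exponent


print(e4m3fn_max())  # 448.0
```

By contrast, standard E4M3 (with infinities) would top out at 240, which is why the FN variant is the one used for KV cache quantization.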

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 22 additions & 4 deletions

```diff
@@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
 REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
 REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
 REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_FP8_KV_CACHE
+REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
+REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
+#endif
 #ifdef USE_INT4_KV_CACHE
 REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
 REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
@@ -292,6 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
   parameters.past_present_share_buffer = (data.past_key == data.present_key);

   bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
+  constexpr bool is_int8 = std::is_same<U, int8_t>::value;
+  constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;

   // Allocate XQA scratch if needed (only for Flash Decoding path)
   IAllocatorUniquePtr<void> xqa_scratch_buffer;
@@ -315,18 +321,30 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
       parameters.local_window_size == -1) {
     int group_size = parameters.num_heads / parameters.kv_num_heads;

-    bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+    bool is_int8_quantized_supported = is_int8 &&
+                                       (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         v_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         data.k_scale == data.v_scale &&  // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor.
-                                        parameters.kv_cache_bit_width == 8 &&
                                         (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                         (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));

+#ifdef USE_FP8_KV_CACHE
+    bool is_fp8_quantized_supported = is_fp8 &&
+                                      (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       data.k_scale == data.v_scale &&
+                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+                                       (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) &&
+                                       (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9)));  // FP8 requires SM89+ (Ada Lovelace)
+#else
+    constexpr bool is_fp8_quantized_supported = false;
+#endif
+
     bool is_non_quantized_supported = !is_inputs_quantized &&
                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                       (64 % group_size == 0);

-    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);

     if (data.use_xqa) {
       size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
@@ -336,7 +354,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
           parameters.kv_num_heads,
           parameters.head_size,
           parameters.seqlen_present_kv_cache,
-          parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+          parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone,
           std::is_same<T, BFloat16>::value);
       assert(xqa_internal_bytes > 0);
       // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
```
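The FP8 XQA gating conditions in this diff (per-tensor quantization for both K and V, a shared k/v scale tensor, supported head sizes and group sizes, and SM89+ hardware) can be restated as a standalone predicate. The sketch below is a Python paraphrase for readability, not the actual C++ code:

```python
def fp8_xqa_supported(head_size, group_size, sm_major, sm_minor,
                      per_tensor_quant=True, shared_kv_scale=True):
    """Paraphrase of is_fp8_quantized_supported from ComputeInternal."""
    # FP8 needs SM89+ hardware: Ada Lovelace (8.9) or Hopper and newer (9.x+).
    has_fp8_hw = sm_major >= 9 or (sm_major == 8 and sm_minor == 9)
    return (per_tensor_quant            # k_quant_type_ == v_quant_type_ == PER_TENSOR
            and shared_kv_scale         # k_scale and v_scale must be the same tensor
            and head_size in (64, 128, 256)
            and group_size in (4, 8, 16, 32)
            and has_fp8_hw)


print(fp8_xqa_supported(128, 8, 8, 9))  # True  (e.g. RTX 4090, SM89)
print(fp8_xqa_supported(128, 8, 8, 0))  # False (e.g. A100, SM80)
```

When the predicate is false, GQA falls back to the non-XQA path (for example, the unsupported-head-size case covered by `test_gqa_fp8_fallback_unsupported_head_size`), so unsupported configurations still run correctly, just without the FP8 flash-decoding kernels.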
