diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 11244e46b78a0..6cfee9e495451 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
+option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON)
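+# Enabled by default. To exclude the fp8 kv cache kernels, pass onnxruntime_USE_FP8_KV_CACHE=OFF (e.g. build.py --cmake_extra_defines onnxruntime_USE_FP8_KV_CACHE=OFF).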
option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA)
message( STATUS "Enable int4 kv cache for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
endif()
+
+ if (onnxruntime_USE_FP8_KV_CACHE)
+ message( STATUS "Enable fp8 kv cache for CUDA EP")
+ list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1)
+ endif()
endif()
if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))
@@ -1442,6 +1448,15 @@ if (Git_FOUND)
if (onnxruntime_USE_INT4_KV_CACHE)
string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
endif()
+ if (onnxruntime_USE_FP8_KV_CACHE)
+ string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
+ endif()
+ if (onnxruntime_DUMP_TENSOR)
+ string(APPEND ORT_BUILD_INFO "dump-tensor=1, ")
+ endif()
+ if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
+ string(APPEND ORT_BUILD_INFO "dump-node=1, ")
+ endif()
endif()
string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 0230f2866fcb4..41811201cbf0e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1003,7 +1003,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)<br> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br> **T_KV_SCALE** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)<br> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br> **T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 39154ca395fc1..a965e00f6a391 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_FP8_KV_CACHE
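+// FP8 (E4M3) KV cache kernel registrations. Note that the XQA decode path additionally requires SM 8.9+ (Ada Lovelace or newer) at runtime.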
+REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
+REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
+#endif
#ifdef USE_INT4_KV_CACHE
REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
@@ -292,6 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
parameters.past_present_share_buffer = (data.past_key == data.present_key);
bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
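+  // U is the KV cache element type of this kernel instantiation; detect the int8 and fp8 quantized cache variants at compile time.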
+  constexpr bool is_int8 = std::is_same<U, int8_t>::value;
+  constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;
// Allocate XQA scratch if needed (only for Flash Decoding path)
IAllocatorUniquePtr<void> xqa_scratch_buffer;
@@ -315,18 +321,30 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
parameters.local_window_size == -1) {
int group_size = parameters.num_heads / parameters.kv_num_heads;
- bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+ bool is_int8_quantized_supported = is_int8 &&
+ (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
v_quant_type_ == KVQuantizationType::PER_TENSOR &&
data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor.
- parameters.kv_cache_bit_width == 8 &&
(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
(group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));
+#ifdef USE_FP8_KV_CACHE
+ bool is_fp8_quantized_supported = is_fp8 &&
+ (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+ v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+ data.k_scale == data.v_scale &&
+ (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+ (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) &&
+ (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9))); // FP8 requires SM89+ (Ada Lovelace)
+#else
+ constexpr bool is_fp8_quantized_supported = false;
+#endif
+
bool is_non_quantized_supported = !is_inputs_quantized &&
(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
(64 % group_size == 0);
- data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+ data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);
if (data.use_xqa) {
size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
@@ -336,7 +354,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
parameters.kv_num_heads,
parameters.head_size,
parameters.seqlen_present_kv_cache,
- parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+ parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone,
        std::is_same<T, BFloat16>::value);
assert(xqa_internal_bytes > 0);
// Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 59e2be5e8cd4b..961c80748d228 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -50,7 +50,6 @@ limitations under the License.
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_type_conversion.h"
-
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
@@ -84,68 +83,58 @@ Status PrepareQKV(
const GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
const T*& q) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
-  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
-  typedef typename onnxruntime::cuda::OrtToCudaType<U>::type CudaU;
-  CudaT* q_out = reinterpret_cast<CudaT*>(data.qkv_buffer);
+  T* q_out = reinterpret_cast<T*>(data.qkv_buffer);
if (!parameters.is_packed_qkv && !parameters.do_rotary) {
q_out = nullptr;
}
-  CudaT* k = reinterpret_cast<CudaT*>(data.present_key);
-  CudaT* v = reinterpret_cast<CudaT*>(data.present_value);
+  U* k = reinterpret_cast<U*>(data.present_key);
+  U* v = reinterpret_cast<U*>(data.present_value);
int max_cache_length = parameters.seqlen_present_kv_cache;
- bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
if (!parameters.past_present_share_buffer) {
- size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU);
+ size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(U);
CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream));
CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream));
}
+ bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+ assert(is_cache_bnsh); // Only support BNSH format for now
+
// Copy past KV to present KV if needed
if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) {
- if (is_cache_bnsh) {
- size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU);
- size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaU);
- size_t width = src_pitch;
- size_t height = (size_t)batch_size * kv_num_heads;
-
- CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height,
- cudaMemcpyDeviceToDevice, stream));
- CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height,
- cudaMemcpyDeviceToDevice, stream));
- } else {
- size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaU);
- size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaU);
- size_t width = src_pitch;
- size_t height = (size_t)batch_size;
-
- CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height,
- cudaMemcpyDeviceToDevice, stream));
- CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height,
- cudaMemcpyDeviceToDevice, stream));
- }
+ size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(U);
+ size_t dst_pitch = (size_t)max_cache_length * head_size * sizeof(U);
+ size_t width = src_pitch;
+ size_t height = (size_t)batch_size * kv_num_heads;
+ CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height,
+ cudaMemcpyDeviceToDevice, stream));
+ CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height,
+ cudaMemcpyDeviceToDevice, stream));
}
-  ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
-      parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
+      parameters.is_packed_qkv ? reinterpret_cast<const T*>(data.query) : nullptr,
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.query),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.key),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.value),
q_out, k, v, data.k_scale, data.v_scale,
num_heads, kv_num_heads, head_size, sequence_length, batch_size,
max_cache_length, data.past_seq_lens,
-      reinterpret_cast<const CudaT*>(data.cos_cache), reinterpret_cast<const CudaT*>(data.sin_cache),
+      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
- is_cache_bnsh, parameters.k_quant_type, parameters.kv_cache_bit_width,
- stream, max_threads_per_block));
+ is_cache_bnsh, parameters.k_quant_type,
+ stream, max_threads_per_block)));
if (q_out != nullptr) {
    q = reinterpret_cast<const T*>(q_out);
@@ -572,6 +561,9 @@ Status ExtremeDecoding(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
ORT_GQA_TRACE("ExtremeDecoding");
const int batch_size = parameters.batch_size;
@@ -584,25 +576,24 @@ Status ExtremeDecoding(
// bool is_causal = parameters.is_unidirectional;
// bool is_bf16 = std::is_same::value;
-  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
// Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V
// This replaces all manual steps (Rotate Q, Rotate K, Quantize, StridedCopy)
-  CudaT* q_rot_ptr = reinterpret_cast<CudaT*>(data.qkv_buffer);
-  const CudaT* q_input_for_xqa = q_rot_ptr;
+  T* q_rot_ptr = reinterpret_cast<T*>(data.qkv_buffer);
+  const T* q_input_for_xqa = q_rot_ptr;
   if (q_rot_ptr == nullptr) {
-    q_input_for_xqa = reinterpret_cast<const CudaT*>(data.query);
+    q_input_for_xqa = reinterpret_cast<const T*>(data.query);
}
-  ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
-      parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
+      parameters.is_packed_qkv ? reinterpret_cast<const T*>(data.query) : nullptr,
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.query),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.key),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.value),
q_rot_ptr, // unpacked_q (can be null if !do_rotary)
- data.present_key,
- data.present_value,
+      reinterpret_cast<U*>(data.present_key),
+      reinterpret_cast<U*>(data.present_value),
data.k_scale,
data.v_scale,
num_heads,
@@ -612,23 +603,24 @@ Status ExtremeDecoding(
batch_size,
parameters.seqlen_present_kv_cache, // max_seqlen (capacity)
data.past_seq_lens,
-      reinterpret_cast<const CudaT*>(data.cos_cache),
-      reinterpret_cast<const CudaT*>(data.sin_cache),
+      reinterpret_cast<const T*>(data.cos_cache),
+      reinterpret_cast<const T*>(data.sin_cache),
parameters.do_rotary ? parameters.rotary_dim : 0,
data.position_ids,
parameters.rotary_interleaved,
!past_bsnh, // is_cache_bnsh
parameters.k_quant_type,
- parameters.kv_cache_bit_width,
stream,
- device_prop.maxThreadsPerBlock));
+ device_prop.maxThreadsPerBlock)));
// Determine workspace size for XQA
void* xqa_workspace = data.xqa_buffer;
size_t xqa_workspace_size = data.xqa_buffer_bytes;
+  constexpr bool is_fp8 = std::is_same<U, __nv_fp8_e4m3>::value;
+ using onnxruntime::contrib::cuda::XqaQuantType;
// 5. Launch XQA
- Status status = onnxruntime::contrib::cuda::LaunchXQAKernel(
+ Status status = onnxruntime::contrib::cuda::LaunchXQAKernel(
device_prop,
stream,
q_input_for_xqa,
@@ -644,8 +636,8 @@ Status ExtremeDecoding(
past_bsnh,
data.past_seq_lens,
data.k_scale, // kv_cache_scale
- // Map KVQuantizationType (0=NONE, 1=TENSOR, 2=CHANNEL) to XqaQuantType (0=FP16/BF16, 1=INT8, 2=FP8)
- (parameters.k_quant_type == KVQuantizationType::NONE) ? onnxruntime::contrib::cuda::XqaQuantType::kNone : onnxruntime::contrib::cuda::XqaQuantType::kInt8,
+ // Map cache type to XqaQuantType: NONE->kNone, Float8E4M3FN->kFp8, int8->kInt8
+ (parameters.k_quant_type == KVQuantizationType::NONE) ? XqaQuantType::kNone : (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8),
xqa_workspace,
xqa_workspace_size);
@@ -668,6 +660,8 @@ Status FlashDecoding(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
assert(!parameters.is_first_prompt && parameters.past_present_share_buffer);
ORT_GQA_TRACE("FlashDecoding");
@@ -680,7 +674,7 @@ Status FlashDecoding(
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
bool is_causal = parameters.is_unidirectional;
-  bool is_bf16 = std::is_same<T, BFloat16>::value || std::is_same<T, __nv_bfloat16>::value;
+  bool is_bf16 = std::is_same<T, __nv_bfloat16>::value;
   void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
void* key;
@@ -732,6 +726,9 @@ Status FlashAttention(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
@@ -742,7 +739,7 @@ Status FlashAttention(
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
bool is_causal = parameters.is_unidirectional;
-  bool is_bf16 = std::is_same<T, BFloat16>::value || std::is_same<T, __nv_bfloat16>::value;
+  bool is_bf16 = std::is_same<T, __nv_bfloat16>::value;
DUMP_TENSOR_INIT();
@@ -798,6 +795,9 @@ Status DequantizeFlashAttentionFallback(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
assert(!parameters.is_first_prompt); // Only support first prompt for this function.
assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE);
@@ -805,62 +805,48 @@ Status DequantizeFlashAttentionFallback(
// We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer).
// Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized]
-  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
-  CudaT* q_rot = reinterpret_cast<CudaT*>(data.qkv_buffer);
+
+  T* q_rot = reinterpret_cast<T*>(data.qkv_buffer);
size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size;
size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size;
- CudaT* k_dequant = q_rot + q_elements;
- CudaT* v_dequant = k_dequant + k_elements;
+ T* k_dequant = q_rot + q_elements;
+ T* v_dequant = k_dequant + k_elements;
// Step 1: Update Quantized Cache
// We can use LaunchUnpackRoPEQuantizeAppend to unpack new QKV, apply RoPE, and append to quantized cache.
// This will also put rotated Q into q_rot.
-  ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
-      parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
-      q_rot, data.present_key, data.present_value, data.k_scale, data.v_scale,
+  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
+      parameters.is_packed_qkv ? reinterpret_cast<const T*>(data.query) : nullptr,
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.query),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.key),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.value),
+      q_rot, reinterpret_cast<U*>(data.present_key), reinterpret_cast<U*>(data.present_value),
+ data.k_scale, data.v_scale,
parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size,
parameters.seqlen_present_kv_cache, data.past_seq_lens,
-      reinterpret_cast<const CudaT*>(data.cos_cache), reinterpret_cast<const CudaT*>(data.sin_cache),
+      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH),
- parameters.k_quant_type, parameters.kv_cache_bit_width,
- stream, device_prop.maxThreadsPerBlock));
+ parameters.k_quant_type,
+ stream, device_prop.maxThreadsPerBlock)));
// Step 2: Dequantize Entire Cache
// We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant.
bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
- if (parameters.kv_cache_bit_width == 8) {
- ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
-        stream, k_dequant, reinterpret_cast<const int8_t*>(data.present_key), data.k_scale,
- nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
- parameters.head_size, 8, parameters.k_quant_type, is_bsnh)));
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+      stream, k_dequant, reinterpret_cast<const U*>(data.present_key), data.k_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, is_bsnh)));
- ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
-        stream, v_dequant, reinterpret_cast<const int8_t*>(data.present_value), data.v_scale,
- nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
- parameters.head_size, 8, parameters.v_quant_type, is_bsnh)));
-#ifdef USE_INT4_KV_CACHE
- } else if (parameters.kv_cache_bit_width == 4) {
- // Int4 support if needed
- ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
-        stream, k_dequant, reinterpret_cast<const uint8_t*>(data.present_key), data.k_scale,
- nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
- parameters.head_size, 4, parameters.k_quant_type, is_bsnh)));
-
- ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
-        stream, v_dequant, reinterpret_cast<const uint8_t*>(data.present_value), data.v_scale,
- nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
- parameters.head_size, 4, parameters.v_quant_type, is_bsnh)));
-#endif
- }
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+      stream, v_dequant, reinterpret_cast<const U*>(data.present_value), data.v_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, is_bsnh)));
// Step 3: Run Flash Attention on dequantized k/v
bool is_causal = parameters.is_unidirectional;
-  bool is_bf16 = std::is_same<T, BFloat16>::value || std::is_same<T, __nv_bfloat16>::value;
+  bool is_bf16 = std::is_same<T, __nv_bfloat16>::value;
// Use the total_seq_lens here since k_dequant/v_dequant has both past and new tokens.
void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens));
@@ -880,7 +866,7 @@ Status DequantizeFlashAttentionFallback(
return Status::OK();
}
-// Use Flash Attention for float key and value, then quantize key/value to int8 to save to k/v cache.
+// Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache.
template <typename T, typename U>
Status FlashAttentionAndQuantizeKV(
const cudaDeviceProp& device_prop,
@@ -888,6 +874,8 @@ Status FlashAttentionAndQuantizeKV(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
assert(parameters.is_first_prompt); // Only support first prompt for this function.
assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE);
@@ -900,37 +888,35 @@ Status FlashAttentionAndQuantizeKV(
ORT_GQA_TRACE("FlashAttentionAndQuantizeKV");
- bool past_bsnh = parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ ORT_ENFORCE(parameters.past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, "GQA only supports BNSH format for KV cache.");
size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size;
size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size;
-  using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
-  CudaT* q_final = reinterpret_cast<CudaT*>(data.qkv_buffer);
+  T* q_final = reinterpret_cast<T*>(data.qkv_buffer);
// For FlashAttentionAndQuantizeKV, we need float K and V for attention.
// We'll write them to qkv_buffer.
- CudaT* k_final = q_final + q_elements;
- CudaT* v_final = k_final + k_elements;
-
-  ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
-      parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
-      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+  T* k_final = q_final + q_elements;
+  T* v_final = k_final + k_elements;
+
+  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, T>(
+      parameters.is_packed_qkv ? reinterpret_cast<const T*>(data.query) : nullptr,
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.query),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.key),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const T*>(data.value),
q_final, k_final, v_final, nullptr, nullptr,
num_heads, kv_num_heads, head_size, sequence_length, batch_size,
sequence_length, data.past_seq_lens,
-      reinterpret_cast<const CudaT*>(data.cos_cache), reinterpret_cast<const CudaT*>(data.sin_cache),
+      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
false, // BSNH for scratch
KVQuantizationType::NONE,
- 0, // bit_width is 0 since we are not quantizing here.
- stream, max_threads_per_block));
+ stream, max_threads_per_block)));
// 2. Run Float Flash Attention
bool is_causal = parameters.is_unidirectional;
-  bool is_bf16 = std::is_same<T, BFloat16>::value || std::is_same<T, __nv_bfloat16>::value;
+  bool is_bf16 = std::is_same<T, __nv_bfloat16>::value;
int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1;
@@ -945,37 +931,18 @@ Status FlashAttentionAndQuantizeKV(
true, // kv_bsnh = true (BSNH)
local_window_size));
- // 3. Quantize K and V to present cache
if (parameters.k_quant_type != KVQuantizationType::NONE) {
- if (parameters.kv_cache_bit_width == 8) {
- ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
-          stream, reinterpret_cast<int8_t*>(data.present_key), reinterpret_cast<const CudaT*>(k_final), data.k_scale,
- nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
- head_size, 8, parameters.k_quant_type, true, past_bsnh)));
-#ifdef USE_INT4_KV_CACHE
- } else if (parameters.kv_cache_bit_width == 4) {
- ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
-          stream, reinterpret_cast<uint8_t*>(data.present_key), reinterpret_cast<const CudaT*>(k_final), data.k_scale,
- nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
- head_size, 4, parameters.k_quant_type, true, past_bsnh)));
-#endif
- }
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+        stream, reinterpret_cast<U*>(data.present_key), reinterpret_cast<const T*>(k_final), data.k_scale,
+ nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
+ head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, true)));
}
if (parameters.v_quant_type != KVQuantizationType::NONE) {
- if (parameters.kv_cache_bit_width == 8) {
- ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
-          stream, reinterpret_cast<int8_t*>(data.present_value), reinterpret_cast<const CudaT*>(v_final), data.v_scale,
- nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
- head_size, 8, parameters.v_quant_type, true, past_bsnh)));
-#ifdef USE_INT4_KV_CACHE
- } else if (parameters.kv_cache_bit_width == 4) {
- ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
-          stream, reinterpret_cast<uint8_t*>(data.present_value), reinterpret_cast<const CudaT*>(v_final), data.v_scale,
- nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
- head_size, 4, parameters.v_quant_type, true, past_bsnh)));
-#endif
- }
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+        stream, reinterpret_cast<U*>(data.present_value), reinterpret_cast<const T*>(v_final), data.v_scale,
+ nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
+ head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, true)));
}
return Status::OK();
@@ -990,6 +957,9 @@ Status EfficientAttention(
GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
float scale) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
@@ -1027,7 +997,7 @@ Status EfficientAttention(
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
-  p.is_bf16 = std::is_same<T, BFloat16>::value || std::is_same<T, __nv_bfloat16>::value;
+  p.is_bf16 = std::is_same<T, __nv_bfloat16>::value;
p.is_half = !p.is_bf16 && (sizeof(T) == 2);
p.batch_size = batch_size;
p.num_heads = num_heads;
@@ -1105,7 +1075,6 @@ Status QkvToContext(
template struct GroupQueryAttentionData<half, half>;
template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>;
-template struct GroupQueryAttentionData<BFloat16, BFloat16>;
template struct GroupQueryAttentionData;
template Status QkvToContext(
@@ -1122,13 +1091,6 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data);
-template Status QkvToContext<BFloat16, BFloat16>(
-    const cudaDeviceProp& device_prop,
-    cublasHandle_t& cublas,
-    Stream* ort_stream,
-    contrib::GroupQueryAttentionParameters& parameters,
-    GroupQueryAttentionData<BFloat16, BFloat16>& data);
-
template Status QkvToContext<half, int8_t>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
@@ -1145,6 +1107,7 @@ template Status QkvToContext<__nv_bfloat16, int8_t>(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<__nv_bfloat16, int8_t>& data);
+#ifdef USE_INT4_KV_CACHE
template struct GroupQueryAttentionData<half, uint8_t>;
template Status QkvToContext<half, uint8_t>(
@@ -1162,6 +1125,27 @@ template Status QkvToContext<__nv_bfloat16, uint8_t>(
Stream* ort_stream,
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<__nv_bfloat16, uint8_t>& data);
+#endif
+
+#ifdef USE_FP8_KV_CACHE
+template struct GroupQueryAttentionData<half, __nv_fp8_e4m3>;
+
+template Status QkvToContext<half, __nv_fp8_e4m3>(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& cublas,
+ Stream* ort_stream,
+ contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<half, __nv_fp8_e4m3>& data);
+
+template struct GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>;
+
+template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& cublas,
+ Stream* ort_stream,
+ contrib::GroupQueryAttentionParameters& parameters,
+ GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data);
+#endif
template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);
template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index 8cd4b44b9832e..78b061837e402 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -133,18 +133,6 @@ Status LaunchGetSequenceLengths(
cudaStream_t stream,
const int max_threads_per_block);
-template <typename T>
-Status LaunchUnpackRoPEAppend(
- const T* packed_qkv, const T* query, const T* key, const T* value,
- T* unpacked_q, void* k_cache, void* v_cache,
- const float* k_scale, const float* v_scale,
- const int num_heads, const int kv_num_heads, const int head_size,
- const int sequence_length, const int batch_size, const int max_seqlen,
- const int* past_seq_lens, const T* cos_cache, const T* sin_cache,
- const int rotary_dim, const int64_t* position_ids, const bool interleaved,
- const bool is_cache_bnsh, const KVQuantizationType k_quant_type,
- const int bit_width, cudaStream_t stream, const int max_threads_per_block);
-
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh
index 3aa9d6d96789a..b69b0238686a6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh
@@ -2,10 +2,11 @@
// Licensed under the MIT License.
#pragma once
-// Enable quantized KV cache support for INT8/INT4
+// Enable quantized KV cache support for INT8/INT4/FP8
#define KV_QUANT_SUPPORTED 1
#include
+#include
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -25,6 +26,7 @@ constexpr int kInt4Max = 7;
constexpr int kInt8Min = -128;
constexpr int kInt8Max = 127;
constexpr int kInt4ZeroPacked = 0x88; // (0 + 8) | ((0 + 8) << 4) for INT4 zero padding
+constexpr float kFp8E4M3Max = 448.0f; // Max value for E4M3 format
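+// 448 = 1.75 * 2^8 is the largest finite value in the E4M3 "fn" encoding, which reserves no bit patterns for infinities.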
constexpr int kThreadsPerBlock = 256;
template
@@ -47,7 +49,7 @@ struct TypeConverter<__nv_bfloat16> {
// ============================================================================
//
// This file implements symmetric quantization for KV cache in GroupQueryAttention.
-// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes.
+// Supports INT4, INT8, and FP8 (E4M3) with PER_TENSOR and PER_CHANNEL quantization modes.
//
// QUANTIZATION SCHEME:
// -------------------
@@ -86,10 +88,15 @@ struct TypeConverter<__nv_bfloat16> {
// -------------
// Cache: BNSH (batch, num_heads, sequence_length, head_size)
// INT4: (head_size + 1) / 2 bytes per head
-// INT8: head_size bytes per head
+// INT8/FP8: head_size bytes per head
+//
+// FP8 E4M3: Native CUDA FP8 format
+// - Range: [-448, 448]
+// - Storage: __nv_fp8_e4m3 (1 byte)
+// - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float
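+//   Example with per-tensor scale s = 0.01: x = 1.234 -> x/s = 123.4, clamped to [-448, 448],
+//   stored as the nearest representable E4M3 value (120); dequantization returns 120 * s = 1.20.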
// ============================================================================
-// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T).
+// Dequantization Kernel: Converts Quantized (Int8/Int4/FP8) KV cache back to Floating Point (T).
// Iterates over every individual element with one thread per element.
template <typename T, typename T_QUANT>
__global__ void DequantizeKernel(T* dequantized_data,
@@ -143,7 +150,14 @@ __global__ void DequantizeKernel(T* dequantized_data,
(bit_width == 4 ? h / 2 : h);
}
- if (bit_width == 8) {
+ // FP8 check must come first since it also has bit_width=8
+#ifdef USE_FP8_KV_CACHE
+  if constexpr (std::is_same<T_QUANT, __nv_fp8_e4m3>::value) {
+    __nv_fp8_e4m3 fp8_val = reinterpret_cast<const __nv_fp8_e4m3*>(quantized_data)[input_idx];
+    quantized_float = static_cast<float>(fp8_val);
+ } else
+#endif
+ if (bit_width == 8) {
    quantized_float = static_cast<float>(
        reinterpret_cast<const int8_t*>(quantized_data)[input_idx]);
#ifdef USE_INT4_KV_CACHE
@@ -181,7 +195,7 @@ Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data,
return CUDA_CALL(cudaGetLastError());
}
-// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values.
+// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4/FP8) values.
// Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization
// or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel.
template
@@ -193,8 +207,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data,
int input_sequence_length,
int cache_sequence_length, int num_heads, int head_size,
int bit_width, KVQuantizationType quant_type,
- bool is_input_bsnh,
- bool is_output_bsnh) {
+ bool is_input_bsnh) {
// elements_per_head_packed is the number of BYTES occupied by head_size elements.
int elements_per_head_packed = (bit_width == 4) ? (head_size + 1) / 2 : head_size;
@@ -220,17 +233,14 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data,
if (s >= total_valid_len_b) {
if (bit_width == 8) {
int64_t out_idx = i;
- if (is_output_bsnh) {
- int64_t b_idx = b;
- int64_t n_idx = n;
- int64_t s_idx = s;
- int64_t h_idx = i % elements_per_head_packed;
- out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed +
- s_idx * num_heads * elements_per_head_packed +
- n_idx * elements_per_head_packed +
- h_idx;
- }
+
      reinterpret_cast<int8_t*>(quantized_data)[out_idx] = 0;
+#ifdef USE_FP8_KV_CACHE
+    } else if constexpr (std::is_same<T_QUANT, __nv_fp8_e4m3>::value) {  // FP8
+ int64_t out_idx = i;
+
+ reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[out_idx] = __nv_fp8_e4m3(0.0f);
+#endif
#ifdef USE_INT4_KV_CACHE
} else if (bit_width == 4) { // INT4
// With packed iteration, each thread handles one byte (2 values).
@@ -242,16 +252,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data,
// Since `h_idx` comes from `i % elements_per_head_packed`, `out_idx` is guaranteed
// to be within the buffer bounds. Writing kInt4ZeroPacked is safe.
int64_t out_idx = i;
- if (is_output_bsnh) {
- int64_t b_idx = b;
- int64_t n_idx = n;
- int64_t s_idx = s;
- int64_t h_idx = i % elements_per_head_packed;
- out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed +
- s_idx * num_heads * elements_per_head_packed +
- n_idx * elements_per_head_packed +
- h_idx;
- }
+
// INT4 uses +8 bias, so zero values pack to 0x88
      reinterpret_cast<uint8_t*>(quantized_data)[out_idx] = kInt4ZeroPacked;
#endif
@@ -260,18 +261,35 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data,
}
int64_t output_idx = i;
- if (is_output_bsnh) {
- int64_t b_idx = b;
- int64_t n_idx = n;
- int64_t s_idx = s;
- int64_t h_idx = i % elements_per_head_packed;
- output_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed +
- s_idx * num_heads * elements_per_head_packed +
- n_idx * elements_per_head_packed +
- h_idx;
- }
- if (bit_width == 8) {
+#ifdef USE_FP8_KV_CACHE
+  if constexpr (std::is_same<T_QUANT, __nv_fp8_e4m3>::value) {
+ int h = h_packed;
+ float scale_val = 1.0f;
+ if (quant_type == KVQuantizationType::PER_TENSOR) {
+      scale_val = static_cast<float>(scale[0]);
+ } else { // PER_CHANNEL
+ int scale_idx = n * head_size + h;
+      scale_val = static_cast<float>(scale[scale_idx]);
+ }
+
+ float inv_scale = (scale_val == 0.0f) ? 0.0f : 1.0f / scale_val;
+ int64_t flattened_input_idx = is_input_bsnh ? (static_cast(b) * input_sequence_length * num_heads * head_size +
+ static_cast(s) * num_heads * head_size +
+ static_cast(n) * head_size +
+ h)
+ : ((int64_t)b * num_heads * input_sequence_length * head_size +
+ (int64_t)n * input_sequence_length * head_size +
+ (int64_t)s * head_size +
+ h);
+    float val_float = static_cast<float>(dequantized_data[flattened_input_idx]) * inv_scale;
+
+ // Clamp to FP8 E4M3 range and convert
+ val_float = fmaxf(-kFp8E4M3Max, fminf(kFp8E4M3Max, val_float));
+ reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[output_idx] = __nv_fp8_e4m3(val_float);
+ } else
+#endif
+ if (bit_width == 8) {
int h = h_packed;
float scale_val = 1.0f;
if (quant_type == KVQuantizationType::PER_TENSOR) {
@@ -363,8 +381,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data,
int batch_size, int num_heads,
int input_sequence_length, int cache_sequence_length, int head_size, int bit_width,
KVQuantizationType quant_type,
- bool is_input_bsnh,
- bool is_output_bsnh) {
+ bool is_input_bsnh) {
assert(total_seq_lens != nullptr);
if (cache_sequence_length == 0) return Status::OK();
@@ -375,7 +392,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data,
QuantizeKernel<<>>(
quantized_data, dequantized_data, scale, past_seq_lens, total_seq_lens, total_packed_elements,
- input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh, is_output_bsnh);
+ input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh);
return CUDA_CALL(cudaGetLastError());
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
index d5c95be316a1f..20f0144c335ee 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
@@ -3,11 +3,16 @@
#pragma once
#include
+#include
+#ifdef USE_FP8_KV_CACHE
+#include <cuda_fp8.h>
+#endif
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cuda/bert/rotary_common.cuh"
#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
using namespace onnxruntime::cuda;
@@ -28,18 +33,19 @@ namespace cuda {
// 4. Writes the rotated Q back to global memory (unpacked_q) for the subsequent attention kernel.
//
// Template Parameters:
-// - T: The floating point type (half or BFloat16).
-// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8, 4=Int4).
+// - T: The floating point type for query (half or __nv_bfloat16).
+// - U: The cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8).
+// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8/FP8, 4=Int4).
// - MAX_HEAD_SIZE: Maximum supported head size, used for shared memory allocation.
-template <typename T, int BIT_WIDTH, int MAX_HEAD_SIZE>
+template <typename T, typename U, int BIT_WIDTH, int MAX_HEAD_SIZE>
__global__ void UnpackRoPEAppend(
const T* packed_qkv,
const T* query,
const T* key,
const T* value,
T* unpacked_q,
- void* k_cache,
- void* v_cache,
+ U* k_cache,
+ U* v_cache,
const float* k_scale,
const float* v_scale,
const int num_heads,
@@ -200,24 +206,39 @@ __global__ void UnpackRoPEAppend(
// No quantization: direct store
      reinterpret_cast<float4*>(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast<const float4*>(vals);
} else if constexpr (BIT_WIDTH == 8) {
- // Int8 Quantization: 1 element per byte
+ // 8-bit quantization: either INT8 or FP8 E4M3 based on cache type U
const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale;
uint64_t packed = 0;
- for (int i = 0; i < elements_per_thread; ++i) {
- float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0];
- float inv_s = (sc == 0.0f) ? 0.0f : 1.0f / sc;
-        int8_t q = static_cast<int8_t>(max(-128.0f, min(127.0f, rintf(static_cast<float>(vals[i]) * inv_s))));
-        packed |= (static_cast<uint64_t>(static_cast<uint8_t>(q)) << (i * 8));
+#ifdef USE_FP8_KV_CACHE
+      if constexpr (std::is_same<U, __nv_fp8_e4m3>::value) {
+ // FP8 E4M3 Quantization: scale and convert to FP8 format
+ constexpr float kFp8E4M3Max = 448.0f;
+ for (int i = 0; i < elements_per_thread; ++i) {
+ float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0];
+          float scaled_val = min(kFp8E4M3Max, max(-kFp8E4M3Max, static_cast<float>(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc)));
+ __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(scaled_val);
+          packed |= (static_cast<uint64_t>(*reinterpret_cast<const uint8_t*>(&fp8_val)) << (i * 8));
+ }
+ } else
+#endif
+ {
+ // INT8 Quantization: round and clamp to [-128, 127]
+ for (int i = 0; i < elements_per_thread; ++i) {
+ float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0];
+          int8_t q = static_cast<int8_t>(max(-128.0f, min(127.0f, rintf(static_cast<float>(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc)))));
+          packed |= (static_cast<uint64_t>(static_cast<uint8_t>(q)) << (i * 8));
+ }
}
// Store 8 elements (8 bytes) at once
-      reinterpret_cast<uint64_t*>(cache_ptr)[cache_idx / 8] = packed;
+      unsigned char* cache_byte_ptr = reinterpret_cast<unsigned char*>((head_type == KEY) ? k_cache : v_cache);
+      reinterpret_cast<uint64_t*>(cache_byte_ptr + cache_idx)[0] = packed;
} else if constexpr (BIT_WIDTH == 4) {
// Int4 Quantization: 2 elements per byte
constexpr float kInt4Min = -8.0f;
constexpr float kInt4Max = 7.0f;
const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale;
uint32_t packed = 0;
- for (int i = 0; i < 4; ++i) {
+ for (int i = 0; i < elements_per_thread / 2; ++i) {
// Elements are paired as (0,1), (2,3), etc. into single bytes.
float s0 = per_channel ? scale_buffer[n * head_size + h + i * 2] : scale_buffer[0];
float s1 = per_channel ? scale_buffer[n * head_size + h + i * 2 + 1] : scale_buffer[0];
@@ -237,28 +258,28 @@ __global__ void UnpackRoPEAppend(
// Internal dispatcher that selects the appropriate template specialization based on head_size.
// MAX_HEAD_SIZE is used to optimize shared memory usage and kernel performance.
-template <typename T, int BIT_WIDTH>
+template <typename T, typename U, int BIT_WIDTH>
Status DispatchUnpackRoPEAppendHeadSize(
const dim3& grid, const dim3& block, cudaStream_t stream,
const T* packed_qkv, const T* query, const T* key, const T* value,
- T* unpacked_q, void* k_cache, void* v_cache,
+ T* unpacked_q, U* k_cache, U* v_cache,
const float* k_scale, const float* v_scale,
const int num_heads, const int kv_num_heads, const int head_size, const int d,
const int max_seqlen, const int* past_seq_lens,
const T* cos_cache, const T* sin_cache, const int rotary_dim,
const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) {
if (head_size <= 64) {
-    UnpackRoPEAppend<T, BIT_WIDTH, 64><<<grid, block, 0, stream>>>(
+    UnpackRoPEAppend<T, U, BIT_WIDTH, 64><<<grid, block, 0, stream>>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
} else if (head_size <= 128) {
-    UnpackRoPEAppend<T, BIT_WIDTH, 128><<<grid, block, 0, stream>>>(
+    UnpackRoPEAppend<T, U, BIT_WIDTH, 128><<<grid, block, 0, stream>>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
} else if (head_size <= 256) {
-    UnpackRoPEAppend<T, BIT_WIDTH, 256><<<grid, block, 0, stream>>>(
+    UnpackRoPEAppend<T, U, BIT_WIDTH, 256><<<grid, block, 0, stream>>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
@@ -269,18 +290,24 @@ Status DispatchUnpackRoPEAppendHeadSize(
}
// Public entry point to launch the Unpack+RoPE+Append kernel.
-// Handles parameter validation, grid/block sizing, and bit-width dispatching.
-template <typename T>
+// Handles parameter validation, grid/block sizing, and type-based dispatching.
+// Template parameters:
+// - T: Query/Key/Value floating point type (half or __nv_bfloat16)
+// - U: Cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8)
+template <typename T, typename U>
Status LaunchUnpackRoPEAppend(
const T* packed_qkv, const T* query, const T* key, const T* value,
- T* unpacked_q, void* k_cache, void* v_cache,
+ T* unpacked_q, U* k_cache, U* v_cache,
const float* k_scale, const float* v_scale,
const int num_heads, const int kv_num_heads, const int head_size,
const int sequence_length, const int batch_size, const int max_seqlen,
const int* past_seq_lens, const T* cos_cache, const T* sin_cache,
const int rotary_dim, const int64_t* position_ids, const bool interleaved,
const bool is_cache_bnsh, const KVQuantizationType k_quant_type,
- const int bit_width, cudaStream_t stream, const int max_threads_per_block) {
+ cudaStream_t stream, const int max_threads_per_block) {
+  static_assert(std::is_same<T, typename onnxruntime::cuda::OrtToCudaType<T>::type>::value);
+  static_assert(std::is_same<U, typename onnxruntime::cuda::OrtToCudaType<U>::type>::value);
+
constexpr int elements_per_vector = sizeof(float4) / sizeof(T);
if (head_size % elements_per_vector != 0) {
@@ -315,26 +342,37 @@ Status LaunchUnpackRoPEAppend(
bool per_channel = (k_quant_type == KVQuantizationType::PER_CHANNEL);
- if (bit_width == 0) {
-    return DispatchUnpackRoPEAppendHeadSize<T, 16>(
+  // Dispatch based on cache type U:
+  // - std::is_same<T, U>: No quantization (BIT_WIDTH=16)
+  // - std::is_same<U, int8_t> or FP8: 8-bit quantization (BIT_WIDTH=8)
+  // - std::is_same<U, uint8_t>: 4-bit quantization (BIT_WIDTH=4)
+  if constexpr (std::is_same<T, U>::value) {
+    // No quantization: cache type same as input type
+    return DispatchUnpackRoPEAppendHeadSize<T, T, 16>(
grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache,
k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
- } else if (bit_width == 8) {
-    return DispatchUnpackRoPEAppendHeadSize<T, 8>(
+  } else if constexpr (std::is_same<U, int8_t>::value
+#ifdef USE_FP8_KV_CACHE
+                       || std::is_same<U, __nv_fp8_e4m3>::value
+#endif
+  ) {
+    // INT8 or FP8 quantization (both 8-bit, distinguished inside kernel by type check)
+    return DispatchUnpackRoPEAppendHeadSize<T, U, 8>(
grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache,
k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
#ifdef USE_INT4_KV_CACHE
- } else if (bit_width == 4) {
-    return DispatchUnpackRoPEAppendHeadSize<T, 4>(
+  } else if constexpr (std::is_same<U, uint8_t>::value) {
+ // INT4 quantization (packed 2 elements per byte)
+ return DispatchUnpackRoPEAppendHeadSize