diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 11244e46b78a0..6cfee9e495451 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF) option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF) +option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON) option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA) message( STATUS "Enable int4 kv cache for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1) endif() + + if (onnxruntime_USE_FP8_KV_CACHE) + message( STATUS "Enable fp8 kv cache for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1) + endif() endif() if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA)) @@ -1442,6 +1448,15 @@ if (Git_FOUND) if (onnxruntime_USE_INT4_KV_CACHE) string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ") endif() + if (onnxruntime_USE_FP8_KV_CACHE) + string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ") + endif() + if (onnxruntime_DUMP_TENSOR) + string(APPEND ORT_BUILD_INFO "dump-tensor=1, ") + endif() + if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) + string(APPEND ORT_BUILD_INFO "dump-node=1, ") + endif() endif() string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}") configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0230f2866fcb4..41811201cbf0e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1003,7 +1003,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)
**T_KV_SCALE** = tensor(float)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 39154ca395fc1..a965e00f6a391 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16) REGISTER_KERNEL_TYPED(BFloat16, BFloat16) REGISTER_KERNEL_TYPED(MLFloat16, int8_t) REGISTER_KERNEL_TYPED(BFloat16, int8_t) +#ifdef USE_FP8_KV_CACHE +REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN) +REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN) +#endif #ifdef USE_INT4_KV_CACHE REGISTER_KERNEL_TYPED(MLFloat16, uint8_t) REGISTER_KERNEL_TYPED(BFloat16, uint8_t) @@ -292,6 +296,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.past_present_share_buffer = (data.past_key == data.present_key); bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE); + constexpr bool is_int8 = std::is_same::value; + constexpr bool is_fp8 = std::is_same::value; // Allocate XQA scratch if needed (only for Flash Decoding path) IAllocatorUniquePtr xqa_scratch_buffer; @@ -315,18 +321,30 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.local_window_size == -1) { int group_size = parameters.num_heads / parameters.kv_num_heads; - bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR && + bool is_int8_quantized_supported = is_int8 && + (k_quant_type_ == KVQuantizationType::PER_TENSOR && v_quant_type_ == KVQuantizationType::PER_TENSOR && data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor. - parameters.kv_cache_bit_width == 8 && (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32)); +#ifdef USE_FP8_KV_CACHE + bool is_fp8_quantized_supported = is_fp8 && + (k_quant_type_ == KVQuantizationType::PER_TENSOR && + v_quant_type_ == KVQuantizationType::PER_TENSOR && + data.k_scale == data.v_scale && + (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && + (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) && + (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9))); // FP8 requires SM89+ (Ada Lovelace) +#else + constexpr bool is_fp8_quantized_supported = false; +#endif + bool is_non_quantized_supported = !is_inputs_quantized && (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && (64 % group_size == 0); - data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported); + data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported); if (data.use_xqa) { size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize( @@ -336,7 +354,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_num_heads, parameters.head_size, parameters.seqlen_present_kv_cache, - parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone, + parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? 
XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone, std::is_same::value); assert(xqa_internal_bytes > 0); // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 59e2be5e8cd4b..961c80748d228 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -50,7 +50,6 @@ limitations under the License. #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_type_conversion.h" - #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -84,68 +83,58 @@ Status PrepareQKV( const GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, const T*& q) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; - CudaT* q_out = reinterpret_cast(data.qkv_buffer); + T* q_out = reinterpret_cast(data.qkv_buffer); if (!parameters.is_packed_qkv && !parameters.do_rotary) { q_out = nullptr; } - CudaT* k = reinterpret_cast(data.present_key); - CudaT* v = reinterpret_cast(data.present_value); + U* k = reinterpret_cast(data.present_key); + U* v = reinterpret_cast(data.present_value); int max_cache_length = parameters.seqlen_present_kv_cache; - bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); if (!parameters.past_present_share_buffer) { - size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU); + size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(U); CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream)); CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); } + bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(is_cache_bnsh); // Only support BNSH format for now + // Copy past KV to present KV if needed if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { - if (is_cache_bnsh) { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaU); - size_t width = src_pitch; - size_t height = (size_t)batch_size * kv_num_heads; - - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - } else { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaU); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaU); - size_t width = src_pitch; - size_t height = (size_t)batch_size; - - 
CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - } + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(U); + size_t dst_pitch = (size_t)max_cache_length * head_size * sizeof(U); + size_t width = src_pitch; + size_t height = (size_t)batch_size * kv_num_heads; + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); } - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_out, k, v, data.k_scale, data.v_scale, num_heads, kv_num_heads, head_size, sequence_length, batch_size, max_cache_length, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, - is_cache_bnsh, parameters.k_quant_type, parameters.kv_cache_bit_width, - stream, max_threads_per_block)); + is_cache_bnsh, parameters.k_quant_type, + stream, max_threads_per_block))); if (q_out != nullptr) { q = reinterpret_cast(q_out); @@ -572,6 +561,9 @@ Status ExtremeDecoding( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + ORT_GQA_TRACE("ExtremeDecoding"); const int batch_size = parameters.batch_size; @@ -584,25 +576,24 @@ Status ExtremeDecoding( // bool is_causal = parameters.is_unidirectional; // bool is_bf16 = std::is_same::value; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); // Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V // This replaces all manual steps (Rotate Q, Rotate K, Quantize, StridedCopy) - CudaT* q_rot_ptr = reinterpret_cast(data.qkv_buffer); - const CudaT* q_input_for_xqa = q_rot_ptr; + T* q_rot_ptr = reinterpret_cast(data.qkv_buffer); + const T* q_input_for_xqa = q_rot_ptr; if (q_rot_ptr == nullptr) { - q_input_for_xqa = reinterpret_cast(data.query); + q_input_for_xqa = reinterpret_cast(data.query); } - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? 
reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_rot_ptr, // unpacked_q (can be null if !do_rotary) - data.present_key, - data.present_value, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), data.k_scale, data.v_scale, num_heads, @@ -612,23 +603,24 @@ Status ExtremeDecoding( batch_size, parameters.seqlen_present_kv_cache, // max_seqlen (capacity) data.past_seq_lens, - reinterpret_cast(data.cos_cache), - reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), + reinterpret_cast(data.sin_cache), parameters.do_rotary ? parameters.rotary_dim : 0, data.position_ids, parameters.rotary_interleaved, !past_bsnh, // is_cache_bnsh parameters.k_quant_type, - parameters.kv_cache_bit_width, stream, - device_prop.maxThreadsPerBlock)); + device_prop.maxThreadsPerBlock))); // Determine workspace size for XQA void* xqa_workspace = data.xqa_buffer; size_t xqa_workspace_size = data.xqa_buffer_bytes; + constexpr bool is_fp8 = std::is_same::value; + using onnxruntime::contrib::cuda::XqaQuantType; // 5. Launch XQA - Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( + Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( device_prop, stream, q_input_for_xqa, @@ -644,8 +636,8 @@ Status ExtremeDecoding( past_bsnh, data.past_seq_lens, data.k_scale, // kv_cache_scale - // Map KVQuantizationType (0=NONE, 1=TENSOR, 2=CHANNEL) to XqaQuantType (0=FP16/BF16, 1=INT8, 2=FP8) - (parameters.k_quant_type == KVQuantizationType::NONE) ? onnxruntime::contrib::cuda::XqaQuantType::kNone : onnxruntime::contrib::cuda::XqaQuantType::kInt8, + // Map cache type to XqaQuantType: NONE->kNone, Float8E4M3FN->kFp8, int8->kInt8 + (parameters.k_quant_type == KVQuantizationType::NONE) ? XqaQuantType::kNone : (is_fp8 ? 
XqaQuantType::kFp8 : XqaQuantType::kInt8), xqa_workspace, xqa_workspace_size); @@ -668,6 +660,8 @@ Status FlashDecoding( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); ORT_GQA_TRACE("FlashDecoding"); @@ -680,7 +674,7 @@ Status FlashDecoding( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; void* query = reinterpret_cast(const_cast(data.query)); void* key; @@ -732,6 +726,9 @@ Status FlashAttention( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -742,7 +739,7 @@ Status FlashAttention( AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; DUMP_TENSOR_INIT(); @@ -798,6 +795,9 @@ Status DequantizeFlashAttentionFallback( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + assert(!parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -805,62 +805,48 @@ Status DequantizeFlashAttentionFallback( // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer). // Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized] - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - CudaT* q_rot = reinterpret_cast(data.qkv_buffer); + + T* q_rot = reinterpret_cast(data.qkv_buffer); size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size; size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size; - CudaT* k_dequant = q_rot + q_elements; - CudaT* v_dequant = k_dequant + k_elements; + T* k_dequant = q_rot + q_elements; + T* v_dequant = k_dequant + k_elements; // Step 1: Update Quantized Cache // We can use LaunchUnpackRoPEQuantizeAppend to unpack new QKV, apply RoPE, and append to quantized cache. // This will also put rotated Q into q_rot. - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), - q_rot, data.present_key, data.present_value, data.k_scale, data.v_scale, + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? 
nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + q_rot, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), + data.k_scale, data.v_scale, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size, parameters.seqlen_present_kv_cache, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), - parameters.k_quant_type, parameters.kv_cache_bit_width, - stream, device_prop.maxThreadsPerBlock)); + parameters.k_quant_type, + stream, device_prop.maxThreadsPerBlock))); // Step 2: Dequantize Entire Cache // We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant. bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 8, parameters.k_quant_type, is_bsnh))); + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, is_bsnh))); - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 8, parameters.v_quant_type, is_bsnh))); -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - // Int4 support if needed - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 4, parameters.k_quant_type, is_bsnh))); - - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 4, parameters.v_quant_type, is_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, is_bsnh))); // Step 3: Run Flash Attention on dequantized k/v bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; // Use the total_seq_lens here since k_dequant/v_dequant has both past and new tokens. void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens)); @@ -880,7 +866,7 @@ Status DequantizeFlashAttentionFallback( return Status::OK(); } -// Use Flash Attention for float key and value, then quantize key/value to int8 to save to k/v cache. 
+// Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache. template Status FlashAttentionAndQuantizeKV( const cudaDeviceProp& device_prop, @@ -888,6 +874,8 @@ Status FlashAttentionAndQuantizeKV( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); assert(parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -900,37 +888,35 @@ Status FlashAttentionAndQuantizeKV( ORT_GQA_TRACE("FlashAttentionAndQuantizeKV"); - bool past_bsnh = parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_ENFORCE(parameters.past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, "GQA only supports BNSH format for KV cache."); size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; - CudaT* q_final = reinterpret_cast(data.qkv_buffer); + T* q_final = reinterpret_cast(data.qkv_buffer); // For FlashAttentionAndQuantizeKV, we need float K and V for attention. // We'll write them to qkv_buffer. - CudaT* k_final = q_final + q_elements; - CudaT* v_final = k_final + k_elements; - - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + T* k_final = q_final + q_elements; + T* v_final = k_final + k_elements; + + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_final, k_final, v_final, nullptr, nullptr, num_heads, kv_num_heads, head_size, sequence_length, batch_size, sequence_length, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, false, // BSNH for scratch KVQuantizationType::NONE, - 0, // bit_width is 0 since we are not quantizing here. - stream, max_threads_per_block)); + stream, max_threads_per_block))); // 2. Run Float Flash Attention bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1; @@ -945,37 +931,18 @@ Status FlashAttentionAndQuantizeKV( true, // kv_bsnh = true (BSNH) local_window_size)); - // 3. 
Quantize K and V to present cache if (parameters.k_quant_type != KVQuantizationType::NONE) { - if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.k_quant_type, true, past_bsnh))); -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 4, parameters.k_quant_type, true, past_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, true))); } if (parameters.v_quant_type != KVQuantizationType::NONE) { - if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.v_quant_type, true, past_bsnh))); -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 4, parameters.v_quant_type, true, past_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, true))); } return Status::OK(); @@ -990,6 +957,9 @@ Status EfficientAttention( GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -1027,7 +997,7 @@ Status EfficientAttention( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; - p.is_bf16 = std::is_same::value || std::is_same::value; + p.is_bf16 = std::is_same::value; p.is_half = !p.is_bf16 && (sizeof(T) == 2); p.batch_size = batch_size; p.num_heads = num_heads; @@ -1105,7 +1075,6 @@ Status QkvToContext( template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; -template struct GroupQueryAttentionData; template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1122,13 +1091,6 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data); -template Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& 
cublas, - Stream* ort_stream, - contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); - template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, @@ -1145,6 +1107,7 @@ template Status QkvToContext<__nv_bfloat16, int8_t>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, int8_t>& data); +#ifdef USE_INT4_KV_CACHE template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1162,6 +1125,27 @@ template Status QkvToContext<__nv_bfloat16, uint8_t>( Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, uint8_t>& data); +#endif + +#ifdef USE_FP8_KV_CACHE +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>; + +template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data); +#endif template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 8cd4b44b9832e..78b061837e402 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -133,18 +133,6 @@ Status LaunchGetSequenceLengths( cudaStream_t stream, const int max_threads_per_block); -template -Status LaunchUnpackRoPEAppend( - const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, - const float* k_scale, const float* v_scale, - const int num_heads, const int kv_num_heads, const int head_size, - const int sequence_length, const int batch_size, const int max_seqlen, - const int* past_seq_lens, const T* cos_cache, const T* sin_cache, - const int rotary_dim, const int64_t* position_ids, const bool interleaved, - const bool is_cache_bnsh, const KVQuantizationType k_quant_type, - const int bit_width, cudaStream_t stream, const int max_threads_per_block); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh index 3aa9d6d96789a..b69b0238686a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh @@ -2,10 +2,11 @@ // Licensed under the MIT License. 
#pragma once -// Enable quantized KV cache support for INT8/INT4 +// Enable quantized KV cache support for INT8/INT4/FP8 #define KV_QUANT_SUPPORTED 1 #include +#include #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -25,6 +26,7 @@ constexpr int kInt4Max = 7; constexpr int kInt8Min = -128; constexpr int kInt8Max = 127; constexpr int kInt4ZeroPacked = 0x88; // (0 + 8) | ((0 + 8) << 4) for INT4 zero padding +constexpr float kFp8E4M3Max = 448.0f; // Max value for E4M3 format constexpr int kThreadsPerBlock = 256; template @@ -47,7 +49,7 @@ struct TypeConverter<__nv_bfloat16> { // ============================================================================ // // This file implements symmetric quantization for KV cache in GroupQueryAttention. -// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes. +// Supports INT4, INT8, and FP8 (E4M3) with PER_TENSOR and PER_CHANNEL quantization modes. // // QUANTIZATION SCHEME: // ------------------- @@ -86,10 +88,15 @@ struct TypeConverter<__nv_bfloat16> { // ------------- // Cache: BNSH (batch, num_heads, sequence_length, head_size) // INT4: (head_size + 1) / 2 bytes per head -// INT8: head_size bytes per head +// INT8/FP8: head_size bytes per head +// +// FP8 E4M3: Native CUDA FP8 format +// - Range: [-448, 448] +// - Storage: __nv_fp8_e4m3 (1 byte) +// - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float // ============================================================================ -// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T). +// Dequantization Kernel: Converts Quantized (Int8/Int4/FP8) KV cache back to Floating Point (T). // Iterates over every individual element with one thread per element. template __global__ void DequantizeKernel(T* dequantized_data, @@ -143,7 +150,14 @@ __global__ void DequantizeKernel(T* dequantized_data, (bit_width == 4 ? h / 2 : h); } - if (bit_width == 8) { + // FP8 check must come first since it also has bit_width=8 +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + __nv_fp8_e4m3 fp8_val = reinterpret_cast(quantized_data)[input_idx]; + quantized_float = static_cast(fp8_val); + } else +#endif + if (bit_width == 8) { quantized_float = static_cast( reinterpret_cast(quantized_data)[input_idx]); #ifdef USE_INT4_KV_CACHE @@ -181,7 +195,7 @@ Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data, return CUDA_CALL(cudaGetLastError()); } -// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values. +// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4/FP8) values. // Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization // or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel. template @@ -193,8 +207,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, int input_sequence_length, int cache_sequence_length, int num_heads, int head_size, int bit_width, KVQuantizationType quant_type, - bool is_input_bsnh, - bool is_output_bsnh) { + bool is_input_bsnh) { // elements_per_head_packed is the number of BYTES occupied by head_size elements. int elements_per_head_packed = (bit_width == 4) ? 
(head_size + 1) / 2 : head_size; @@ -220,17 +233,14 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, if (s >= total_valid_len_b) { if (bit_width == 8) { int64_t out_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } + reinterpret_cast(quantized_data)[out_idx] = 0; +#ifdef USE_FP8_KV_CACHE + } else if constexpr (std::is_same::value) { // FP8 + int64_t out_idx = i; + + reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[out_idx] = __nv_fp8_e4m3(0.0f); +#endif #ifdef USE_INT4_KV_CACHE } else if (bit_width == 4) { // INT4 // With packed iteration, each thread handles one byte (2 values). @@ -242,16 +252,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, // Since `h_idx` comes from `i % elements_per_head_packed`, `out_idx` is guaranteed // to be within the buffer bounds. Writing kInt4ZeroPacked is safe. int64_t out_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } + // INT4 uses +8 bias, so zero values pack to 0x88 reinterpret_cast(quantized_data)[out_idx] = kInt4ZeroPacked; #endif @@ -260,18 +261,35 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, } int64_t output_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - output_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } - if (bit_width == 8) { +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + int h = h_packed; + float scale_val = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale_val = static_cast(scale[0]); + } else { // PER_CHANNEL + int scale_idx = n * head_size + h; + scale_val = static_cast(scale[scale_idx]); + } + + float inv_scale = (scale_val == 0.0f) ? 0.0f : 1.0f / scale_val; + int64_t flattened_input_idx = is_input_bsnh ? 
(static_cast(b) * input_sequence_length * num_heads * head_size + + static_cast(s) * num_heads * head_size + + static_cast(n) * head_size + + h) + : ((int64_t)b * num_heads * input_sequence_length * head_size + + (int64_t)n * input_sequence_length * head_size + + (int64_t)s * head_size + + h); + float val_float = static_cast(dequantized_data[flattened_input_idx]) * inv_scale; + + // Clamp to FP8 E4M3 range and convert + val_float = fmaxf(-kFp8E4M3Max, fminf(kFp8E4M3Max, val_float)); + reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[output_idx] = __nv_fp8_e4m3(val_float); + } else +#endif + if (bit_width == 8) { int h = h_packed; float scale_val = 1.0f; if (quant_type == KVQuantizationType::PER_TENSOR) { @@ -363,8 +381,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data, int batch_size, int num_heads, int input_sequence_length, int cache_sequence_length, int head_size, int bit_width, KVQuantizationType quant_type, - bool is_input_bsnh, - bool is_output_bsnh) { + bool is_input_bsnh) { assert(total_seq_lens != nullptr); if (cache_sequence_length == 0) return Status::OK(); @@ -375,7 +392,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data, QuantizeKernel<<>>( quantized_data, dequantized_data, scale, past_seq_lens, total_seq_lens, total_packed_elements, - input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh, is_output_bsnh); + input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index d5c95be316a1f..20f0144c335ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -3,11 +3,16 @@ #pragma once #include +#include +#ifdef USE_FP8_KV_CACHE +#include +#endif #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/rotary_common.cuh" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "core/providers/cuda/shared_inc/cuda_call.h" using namespace onnxruntime::cuda; @@ -28,18 +33,19 @@ namespace cuda { // 4. Writes the rotated Q back to global memory (unpacked_q) for the subsequent attention kernel. // // Template Parameters: -// - T: The floating point type (half or BFloat16). -// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8, 4=Int4). +// - T: The floating point type for query (half or __nv_bfloat16). +// - U: The cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8). +// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8/FP8, 4=Int4). // - MAX_HEAD_SIZE: Maximum supported head size, used for shared memory allocation. 
-template +template __global__ void UnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, T* unpacked_q, - void* k_cache, - void* v_cache, + U* k_cache, + U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, @@ -200,24 +206,39 @@ __global__ void UnpackRoPEAppend( // No quantization: direct store reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); } else if constexpr (BIT_WIDTH == 8) { - // Int8 Quantization: 1 element per byte + // 8-bit quantization: either INT8 or FP8 E4M3 based on cache type U const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; uint64_t packed = 0; - for (int i = 0; i < elements_per_thread; ++i) { - float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; - float inv_s = (sc == 0.0f) ? 0.0f : 1.0f / sc; - int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * inv_s)))); - packed |= (static_cast(static_cast(q)) << (i * 8)); +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + // FP8 E4M3 Quantization: scale and convert to FP8 format + constexpr float kFp8E4M3Max = 448.0f; + for (int i = 0; i < elements_per_thread; ++i) { + float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; + float scaled_val = min(kFp8E4M3Max, max(-kFp8E4M3Max, static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))); + __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(scaled_val); + packed |= (static_cast(*reinterpret_cast(&fp8_val)) << (i * 8)); + } + } else +#endif + { + // INT8 Quantization: round and clamp to [-128, 127] + for (int i = 0; i < elements_per_thread; ++i) { + float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; + int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))))); + packed |= (static_cast(static_cast(q)) << (i * 8)); + } } // Store 8 elements (8 bytes) at once - reinterpret_cast(cache_ptr)[cache_idx / 8] = packed; + unsigned char* cache_byte_ptr = reinterpret_cast((head_type == KEY) ? k_cache : v_cache); + reinterpret_cast(cache_byte_ptr + cache_idx)[0] = packed; } else if constexpr (BIT_WIDTH == 4) { // Int4 Quantization: 2 elements per byte constexpr float kInt4Min = -8.0f; constexpr float kInt4Max = 7.0f; const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; uint32_t packed = 0; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < elements_per_thread / 2; ++i) { // Elements are paired as (0,1), (2,3), etc. into single bytes. float s0 = per_channel ? scale_buffer[n * head_size + h + i * 2] : scale_buffer[0]; float s1 = per_channel ? scale_buffer[n * head_size + h + i * 2 + 1] : scale_buffer[0]; @@ -237,28 +258,28 @@ __global__ void UnpackRoPEAppend( // Internal dispatcher that selects the appropriate template specialization based on head_size. // MAX_HEAD_SIZE is used to optimize shared memory usage and kernel performance. 
-template +template Status DispatchUnpackRoPEAppendHeadSize( const dim3& grid, const dim3& block, cudaStream_t stream, const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, + T* unpacked_q, U* k_cache, U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int d, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) { if (head_size <= 64) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); } else if (head_size <= 128) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); } else if (head_size <= 256) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); @@ -269,18 +290,24 @@ Status DispatchUnpackRoPEAppendHeadSize( } // Public entry point to launch the Unpack+RoPE+Append kernel. -// Handles parameter validation, grid/block sizing, and bit-width dispatching. -template +// Handles parameter validation, grid/block sizing, and type-based dispatching. 
+// Template parameters: +// - T: Query/Key/Value floating point type (half or __nv_bfloat16) +// - U: Cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8) +template Status LaunchUnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, + T* unpacked_q, U* k_cache, U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, - const int bit_width, cudaStream_t stream, const int max_threads_per_block) { + cudaStream_t stream, const int max_threads_per_block) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + constexpr int elements_per_vector = sizeof(float4) / sizeof(T); if (head_size % elements_per_vector != 0) { @@ -315,26 +342,37 @@ Status LaunchUnpackRoPEAppend( bool per_channel = (k_quant_type == KVQuantizationType::PER_CHANNEL); - if (bit_width == 0) { - return DispatchUnpackRoPEAppendHeadSize( + // Dispatch based on cache type U: + // - std::is_same: No quantization (BIT_WIDTH=16) + // - std::is_same or FP8: 8-bit quantization (BIT_WIDTH=8) + // - std::is_same: 4-bit quantization (BIT_WIDTH=4) + if constexpr (std::is_same::value) { + // No quantization: cache type same as input type + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); - } else if (bit_width == 8) { - return DispatchUnpackRoPEAppendHeadSize( + } else if constexpr (std::is_same::value +#ifdef USE_FP8_KV_CACHE + || std::is_same::value +#endif + ) { + // INT8 or FP8 quantization (both 8-bit, distinguished inside kernel by type check) + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); #ifdef USE_INT4_KV_CACHE - } else if (bit_width == 4) { - return DispatchUnpackRoPEAppendHeadSize( + } else if constexpr (std::is_same::value) { + // INT4 quantization (packed 2 elements per byte) + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); #endif + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported cache type U for GQA quantization."); } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported bit_width (", bit_width, ") for GQA quantization."); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h index 5aa78aa242306..d803cb6fba531 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h @@ -50,7 +50,7 @@ constexpr uint32_t inputSeqLen = 1; // speculative decoding if > 1 constexpr 
bool useKVCache = USE_KV_CACHE; using SeqLenDataType = uint32_t; -#endif +#endif // MHA_H_COMMON // Dependent definitions #ifndef MHA_H_DEPENDENT diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu new file mode 100644 index 0000000000000..612f2fd14f09a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu new file mode 100644 index 0000000000000..9329679593e7c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu new file mode 100644 index 0000000000000..d3144b5bb7e2b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh new file mode 100644 index 0000000000000..481fcb63c1f8c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_bf16_fp8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_bf16_fp8_impl.cuh" +#endif + +// Define global constants for FP8 E4M3 KV Cache with BF16 Query +#define CACHE_ELEM_ENUM 2 // FP8 E4M3 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 0 // Q is BF16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// FP8 E4M3 KV Cache Instantiations for BF16 Query +// ============================================================================ + +#define NAMESPACE_NAME grp4_bf16_fp8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_bf16_fp8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_bf16_fp8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_bf16_fp8 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAFp8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, 
max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh index 644dec2c67bbd..c2d9c057c6e50 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -116,6 +116,28 @@ Status LaunchXQAInt8KernelBF16( void* workspace, size_t workspace_size); +#ifdef USE_FP8_KV_CACHE +// Extern declarations for FP8 kernels with BF16 query (implemented in xqa_loader_bf16_fp8_impl.cuh) +Status LaunchXQAFp8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +#endif + // ============================================================================ // Specialization for BFloat16 // ============================================================================ @@ -171,6 +193,16 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( workspace_size); } +#ifdef USE_FP8_KV_CACHE + // Dispatch to FP8 path if requested + if (kv_quant_type == XqaQuantType::kFp8) { + return LaunchXQAFp8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, + batch_size, num_heads, kv_num_heads, head_size, max_seq_len, + scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, + workspace_size); + } +#endif + int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu new file mode 100644 index 0000000000000..f9697fdd2f614 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu new file mode 100644 index 0000000000000..3f5d9ac3f5507 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu new file mode 100644 index 0000000000000..ce894ebc384a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh new file mode 100644 index 0000000000000..5e18d21defb79 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_fp16_fp8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_fp16_fp8_impl.cuh" +#endif + +// Define global constants for FP8 E4M3 KV Cache +#define CACHE_ELEM_ENUM 2 // FP8 E4M3 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 1 // Q is FP16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// FP8 E4M3 KV Cache Instantiations for FP16 Query +// ============================================================================ +#define NAMESPACE_NAME grp4_fp8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_fp8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_fp8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_fp8 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAFp8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, 
kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh index 8ba0fe3b1ee0d..675beb3c92d0f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -116,6 +116,28 @@ Status LaunchXQAInt8Kernel( void* workspace, size_t workspace_size); +#ifdef USE_FP8_KV_CACHE +// Extern declarations for FP8 kernels (implemented in xqa_loader_fp16_fp8_impl.cuh via instantiation) +Status LaunchXQAFp8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +#endif + // ============================================================================ // Dispatcher Implementation // ============================================================================ @@ -152,6 +174,18 @@ Status LaunchXQAKernelImpl( } } +#ifdef USE_FP8_KV_CACHE + // Dispatch to FP8 path if requested + if (kv_quant_type == XqaQuantType::kFp8) { + if constexpr (std::is_same::value) { + return LaunchXQAFp8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else { + // BF16 case is handled in xqa_loader_bf16.cu via specialization + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 path mismatch."); + } + } +#endif + int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index ab692e0549d6c..e73ad25d96f38 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -115,6 +115,10 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_int8_t, GroupQueryAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_uint8_t, GroupQueryAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_uint8_t, GroupQueryAttention); #endif +#ifdef USE_FP8_KV_CACHE +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_Float8E4M3FN, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_Float8E4M3FN, GroupQueryAttention); +#endif class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention); @@ -361,6 
+365,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { #ifdef USE_INT4_KV_CACHE BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif +#ifdef USE_FP8_KV_CACHE + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index fee5d9556e75b..092c05f9e081a 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -342,8 +342,14 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte : past_dims[2].dim_value(); present_shape.add_dim()->set_dim_value(present_sequence_length); } else { - // Cannot compute exact present_sequence_length, copy from past_key (may be dynamic) - *present_shape.add_dim() = past_dims[2]; + // Cannot compute exact present_sequence_length. + if (ctx.getNumInputs() > 6 && past_dims[2].has_dim_value() && past_dims[2].dim_value() == 0) { + // If total_sequence_length is provided and past_key has 0 length, present_key will grow. + // Leave the dimension as dynamic to avoid "Error merging shape info" warning. + present_shape.add_dim(); + } else { + *present_shape.add_dim() = past_dims[2]; + } } *present_shape.add_dim() = past_dims[3]; // head_size diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 6c235f95aabcf..ef0dd065db523 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -11,6 +11,7 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; @@ -116,8 +117,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { contribop_parameters.is_output_bnsh = false; } - typedef typename ToCudaType::MappedType CudaT; - // Check if this is Group Query Attention (GQA) const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads; @@ -196,6 +195,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { gqa_parameters.num_splits = 1; // Construct GroupQueryAttentionData + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data; // Scratch buffers for flash/memory efficient attention @@ -481,151 +481,152 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { return onnxruntime::contrib::cuda::QkvToContext( device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data); - } + } else { // MHA path (kv_num_heads == q_num_heads) + typedef typename ToCudaType::MappedType CudaT; + contribop_parameters.batch_size = parameters.batch_size; + contribop_parameters.sequence_length = parameters.q_sequence_length; + contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; + contribop_parameters.past_sequence_length = parameters.past_sequence_length; + contribop_parameters.total_sequence_length = parameters.total_sequence_length; + // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) + contribop_parameters.max_sequence_length = parameters.total_sequence_length; + contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V + 
contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; + contribop_parameters.head_size = parameters.head_size; + contribop_parameters.v_head_size = parameters.v_head_size; + contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + contribop_parameters.num_heads = parameters.q_num_heads; + contribop_parameters.rotary_dim = 0; + contribop_parameters.num_splits = 1; + contribop_parameters.beam_width = 1; + contribop_parameters.is_unidirectional = parameters.is_causal; + contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer + contribop_parameters.is_packed_qkv = false; + contribop_parameters.do_rotary = false; + + // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask + // So mask_type should always be MASK_NONE since we don't have a separate padding mask input + contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; + + // Determine broadcast flags for attention_bias (if it exists) + // Note: The new Attention op uses attn_mask as attention_bias + // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) + // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions). + if (attn_mask != nullptr) { + // TODO(titaiwang, xadupre): attn_mask bool is not supported yet + if (attn_mask->IsDataType()) { + ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA)."); + } - // MHA path (kv_num_heads == q_num_heads) - contribop_parameters.batch_size = parameters.batch_size; - contribop_parameters.sequence_length = parameters.q_sequence_length; - contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; - contribop_parameters.past_sequence_length = parameters.past_sequence_length; - contribop_parameters.total_sequence_length = parameters.total_sequence_length; - // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) - contribop_parameters.max_sequence_length = parameters.total_sequence_length; - contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V - contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - contribop_parameters.head_size = parameters.head_size; - contribop_parameters.v_head_size = parameters.v_head_size; - contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - contribop_parameters.num_heads = parameters.q_num_heads; - contribop_parameters.rotary_dim = 0; - contribop_parameters.num_splits = 1; - contribop_parameters.beam_width = 1; - contribop_parameters.is_unidirectional = parameters.is_causal; - contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer - contribop_parameters.is_packed_qkv = false; - contribop_parameters.do_rotary = false; - - // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask - // So mask_type should always be MASK_NONE since we don't have a separate padding mask input - contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - - // Determine broadcast flags for attention_bias (if it exists) - // Note: The new Attention op uses attn_mask as attention_bias - // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) - // attn_mask can be 
2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions). - if (attn_mask != nullptr) { - // TODO(titaiwang, xadupre): attn_mask bool is not supported yet - if (attn_mask->IsDataType()) { - ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA)."); + size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); + auto attn_mask_dims = attn_mask->Shape().GetDims(); + // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting + // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting + // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 + + if (attn_mask_dims_size == 2) { + // 2D mask: both dimensions need broadcasting + contribop_parameters.broadcast_attn_bias_dim_0 = true; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else if (attn_mask_dims_size == 3) { + // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else { + // 4D mask: check both dim 0 and dim 1 explicitly + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; + } + } else { + contribop_parameters.broadcast_attn_bias_dim_0 = false; + contribop_parameters.broadcast_attn_bias_dim_1 = false; } - size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); - auto attn_mask_dims = attn_mask->Shape().GetDims(); - // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting - // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting - // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 - - if (attn_mask_dims_size == 2) { - // 2D mask: both dimensions need broadcasting - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask_dims_size == 3) { - // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else { - // 4D mask: check both dim 0 and dim 1 explicitly - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; + contribop_parameters.mask_filter_value = -10000.0f; + contribop_parameters.scale = parameters.scale; + contribop_parameters.use_tf32 = UseTF32(); + // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now + if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && + qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { + ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); + } + // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet + if (parameters.softcap != 0.0f) { + ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); + } + if (parameters.softmax_precision != 0) { + ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); } - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = false; - } - contribop_parameters.mask_filter_value = -10000.0f; - 
contribop_parameters.scale = parameters.scale; - contribop_parameters.use_tf32 = UseTF32(); - // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now - if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && - qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { - ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); - } - // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet - if (parameters.softcap != 0.0f) { - ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); - } - if (parameters.softmax_precision != 0) { - ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); - } + // Construct AttentionData to pass to QkvToContext + onnxruntime::contrib::cuda::AttentionData data; - // Construct AttentionData to pass to QkvToContext - onnxruntime::contrib::cuda::AttentionData data; - - // Set input pointers - data.query = reinterpret_cast(Q->Data()); - data.key = reinterpret_cast(K->Data()); - data.value = reinterpret_cast(V->Data()); - data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask - data.mask_index_dims = gsl::span(); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - - // Set output pointers - data.output = reinterpret_cast(Y->MutableData()); - data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast(present_value->MutableData()); - if (nullptr != output_qk) { - data.output_qk = reinterpret_cast(output_qk->MutableData()); - } + // Set input pointers + data.query = reinterpret_cast(Q->Data()); + data.key = reinterpret_cast(K->Data()); + data.value = reinterpret_cast(V->Data()); + data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask + data.mask_index_dims = gsl::span(); + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + + // Set output pointers + data.output = reinterpret_cast(Y->MutableData()); + data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (present_value == nullptr) ? 
nullptr : reinterpret_cast(present_value->MutableData()); + if (nullptr != output_qk) { + data.output_qk = reinterpret_cast(output_qk->MutableData()); + } - // Set additional fields - data.bias = nullptr; // New Attention op doesn't have bias - if (nullptr != attn_mask) { - data.attention_bias = reinterpret_cast(attn_mask->Data()); + // Set additional fields + data.bias = nullptr; // New Attention op doesn't have bias + if (nullptr != attn_mask) { + data.attention_bias = reinterpret_cast(attn_mask->Data()); + } + data.qkv_format = contribop_parameters.qkv_format; + + // For now, set flags to false and let QkvToContext use the unfused path + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; + data.fused_runner = nullptr; + data.fused_cross_attention_kernel = nullptr; + data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; + + // Allocate workspace for Q, K, V processing and scratch buffer + const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); + size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( + sizeof(T), + contribop_parameters.batch_size, + contribop_parameters.num_heads, + contribop_parameters.head_size, + contribop_parameters.v_head_size, + contribop_parameters.sequence_length, + contribop_parameters.kv_sequence_length, + contribop_parameters.total_sequence_length, + nullptr, // fused_runner + false, // use_flash_attention + false, // use_lean_attention + false, // use_fused_cross_attention + false, // use_memory_efficient_attention + false, // use_cudnn_flash_attention + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + // Call QkvToContext to perform the attention computation + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); + + // QkvToContext takes two template parameters: T for computation type, QK for output_qk type + // For now, both are the same type (CudaT) + + return onnxruntime::contrib::cuda::QkvToContext( + device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); } - data.qkv_format = contribop_parameters.qkv_format; - - // For now, set flags to false and let QkvToContext use the unfused path - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace for Q, K, V processing and scratch buffer - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, // fused_runner - false, // use_flash_attention - false, // use_lean_attention - false, // use_fused_cross_attention - false, // use_memory_efficient_attention - false, // use_cudnn_flash_attention - no_qkv_workspace); - auto work_space = 
GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - // Call QkvToContext to perform the attention computation - auto& device_prop = GetDeviceProp(); - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); - - // QkvToContext takes two template parameters: T for computation type, QK for output_qk type - // For now, both are the same type (CudaT) - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); } } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 3caaca9663c2e..ced5f4ebec3aa 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -35,13 +35,20 @@ def get_output_type(ort_session, name: str) -> str: @staticmethod def ort_type_to_numpy_type(ort_type: str): ort_type_to_numpy_type_map = { - "tensor(int64)": numpy.longlong, - "tensor(int32)": numpy.intc, + "tensor(int64)": numpy.int64, + "tensor(int32)": numpy.int32, "tensor(float)": numpy.float32, "tensor(float16)": numpy.float16, "tensor(bool)": bool, "tensor(uint8)": numpy.uint8, "tensor(int8)": numpy.int8, + "tensor(double)": numpy.float64, + "tensor(int16)": numpy.int16, + "tensor(uint16)": numpy.uint16, + "tensor(uint32)": numpy.uint32, + "tensor(uint64)": numpy.uint64, + "tensor(complex64)": numpy.complex64, + "tensor(complex128)": numpy.complex128, } if ort_type not in ort_type_to_numpy_type_map: raise ValueError(f"{ort_type} not found in map") @@ -59,6 +66,19 @@ def ort_type_to_torch_type(ort_type: str): "tensor(bool)": torch.bool, "tensor(uint8)": torch.uint8, "tensor(int8)": torch.int8, + "tensor(double)": torch.float64, + "tensor(int16)": torch.int16, + "tensor(uint16)": torch.uint16, + "tensor(uint32)": torch.uint32, + "tensor(uint64)": torch.uint64, + "tensor(complex64)": torch.complex64, + "tensor(complex128)": torch.complex128, + "tensor(float8e4m3fn)": torch.float8_e4m3fn, + "tensor(float8e4m3fnuz)": torch.float8_e4m3fnuz, + "tensor(float8e5m2)": torch.float8_e5m2, + "tensor(float8e5m2fnuz)": torch.float8_e5m2fnuz, + "tensor(int4)": torch.int4, + "tensor(uint4)": torch.uint4, } if ort_type not in ort_type_to_torch_type_map: raise ValueError(f"{ort_type} not found in map") @@ -87,6 +107,21 @@ def ort_type_to_onnx_type(ort_type: str): "tensor(bool)": TensorProto.BOOL, "tensor(uint8)": TensorProto.UINT8, "tensor(int8)": TensorProto.INT8, + "tensor(double)": TensorProto.DOUBLE, + "tensor(int16)": TensorProto.INT16, + "tensor(uint16)": TensorProto.UINT16, + "tensor(uint32)": TensorProto.UINT32, + "tensor(uint64)": TensorProto.UINT64, + "tensor(complex64)": TensorProto.COMPLEX64, + "tensor(complex128)": TensorProto.COMPLEX128, + "tensor(float8e4m3fn)": TensorProto.FLOAT8E4M3FN, + "tensor(float8e4m3fnuz)": TensorProto.FLOAT8E4M3FNUZ, + "tensor(float8e5m2)": TensorProto.FLOAT8E5M2, + "tensor(float8e5m2fnuz)": TensorProto.FLOAT8E5M2FNUZ, + "tensor(float4e2m1)": TensorProto.FLOAT4E2M1, + "tensor(int4)": TensorProto.INT4, + "tensor(uint4)": TensorProto.UINT4, + "tensor(string)": TensorProto.STRING, } if ort_type not in ort_type_to_onnx_type_map: raise ValueError(f"{ort_type} not found in map") @@ -96,15 +131,22 @@ def 
ort_type_to_onnx_type(ort_type: str): @staticmethod def numpy_type_to_torch_type(numpy_type: numpy.dtype): numpy_type_to_torch_type_map = { - numpy.longlong: torch.int64, - numpy.intc: torch.int32, + numpy.int64: torch.int64, numpy.int32: torch.int32, numpy.float32: torch.float32, numpy.float16: torch.float16, bool: torch.bool, numpy.uint8: torch.uint8, numpy.int8: torch.int8, + numpy.float64: torch.float64, + numpy.int16: torch.int16, + numpy.uint16: torch.uint16, + numpy.uint32: torch.uint32, + numpy.uint64: torch.uint64, + numpy.complex64: torch.complex64, + numpy.complex128: torch.complex128, } + if numpy_type not in numpy_type_to_torch_type_map: raise ValueError(f"{numpy_type} not found in map") @@ -113,13 +155,22 @@ def numpy_type_to_torch_type(numpy_type: numpy.dtype): @staticmethod def torch_type_to_numpy_type(torch_type: torch.dtype): torch_type_to_numpy_type_map = { - torch.int64: numpy.longlong, - torch.int32: numpy.intc, + torch.int64: numpy.int64, + torch.int32: numpy.int32, torch.float32: numpy.float32, torch.float16: numpy.float16, torch.bool: bool, torch.uint8: numpy.uint8, + torch.int8: numpy.int8, + torch.float64: numpy.float64, + torch.int16: numpy.int16, + torch.uint16: numpy.uint16, + torch.uint32: numpy.uint32, + torch.uint64: numpy.uint64, + torch.complex64: numpy.complex64, + torch.complex128: numpy.complex128, } + if torch_type not in torch_type_to_numpy_type_map: raise ValueError(f"{torch_type} not found in map") diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index c44d6b606d3a2..3a835d0852a9d 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -22,6 +22,7 @@ class TestConfig: test_int4: bool = False test_int8: bool = False + test_fp8: bool = False def get_plot_algos(sm: int, local_window_size: int | None, config: TestConfig | None): @@ -37,17 +38,21 @@ def get_plot_algos(sm: int, local_window_size: int | None, config: TestConfig | # Add quantized variants if requested if sm >= 80 and config: - quant_vals = ["ort_gqa_int4", "ort_gqa_int8"] - quant_names = ["ORT-GQA-INT4", "ORT-GQA-INT8"] - quant_styles = [("purple", "dotted"), ("orange", "dashdot")] + quant_vals = ["ort_gqa_int4", "ort_gqa_int8", "ort_gqa_fp8"] + quant_names = ["ORT-GQA-INT4", "ORT-GQA-INT8", "ORT-GQA-FP8"] + quant_styles = [("purple", "dotted"), ("orange", "dashdot"), ("brown", "dashed")] if config.test_int4: - line_vals.extend(quant_vals[:1]) - line_names.extend(quant_names[:1]) - styles.extend(quant_styles[:1]) + line_vals.append(quant_vals[0]) + line_names.append(quant_names[0]) + styles.append(quant_styles[0]) if config.test_int8: - line_vals.extend(quant_vals[1:]) - line_names.extend(quant_names[1:]) - styles.extend(quant_styles[1:]) + line_vals.append(quant_vals[1]) + line_names.append(quant_names[1]) + styles.append(quant_styles[1]) + if config.test_fp8: + line_vals.append(quant_vals[2]) + line_names.append(quant_names[2]) + styles.append(quant_styles[2]) return { "line_vals": line_vals, @@ -116,6 +121,9 @@ def benchmark( elif "_int8" in provider: k_quant_type = v_quant_type = "PER_TENSOR" kv_cache_type = "int8" + elif "_fp8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "fp8" config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( batch_size=batch_size, @@ -205,6 +213,10 @@ def benchmark( k_quant_type = v_quant_type = "PER_TENSOR" kv_cache_type = "int8" share_kv_scale = True # XQA requires shared 
scale + elif "_fp8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "fp8" + share_kv_scale = True # XQA requires shared scale config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( batch_size=batch_size, @@ -303,7 +315,7 @@ def run_performance_test( s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - config = TestConfig(test_int4=False, test_int8=True) + config = TestConfig(test_int4=False, test_int8=True, test_fp8=True) run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=True) run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=False) # run_performance_test(sm, fast=True, config=config, dtype="bfloat16", is_prompt=True) diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index cd34f4f420ad5..7f0d50a7ac8ed 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -21,6 +21,7 @@ "int32": TensorProto.INT32, "int8": TensorProto.INT8, "int4": TensorProto.UINT8, + "fp8": TensorProto.FLOAT8E4M3FN, } TORCH_DTYPE_TO_ONNX_MAP = { @@ -29,6 +30,7 @@ torch.bfloat16: TensorProto.BFLOAT16, torch.int32: TensorProto.INT32, torch.int8: TensorProto.INT8, + torch.float8_e4m3fn: TensorProto.FLOAT8E4M3FN, } TORCH_DTYPE_MAP = { @@ -37,6 +39,7 @@ "bfloat16": torch.bfloat16, "int8": torch.int8, "int4": torch.uint8, + "fp8": torch.float8_e4m3fn, } NUMPY_DTYPE_MAP = { @@ -45,6 +48,7 @@ "bfloat16": numpy.uint16, "int8": numpy.int8, "int4": numpy.uint8, + "fp8": numpy.uint8, # FP8 E4M3 stored as uint8 } @@ -54,6 +58,8 @@ def get_q_range(q_type_str): return -128, 127 if q_type_str.endswith("int4"): return -8, 7 + if q_type_str == "fp8": + return -448.0, 448.0 # FP8 E4M3 range raise ValueError(f"Unsupported quantization type for range: {q_type_str}") @@ -108,8 +114,14 @@ def dequantize_tensor(quantized_tensor, scale, quant_type, q_type_str): if isinstance(scale, torch.Tensor): scale = scale.to(quantized_tensor.device) - unpacked_tensor = quantized_tensor q_type_str_s = str(q_type_str) + + # FP8 dequantization: cast to float32 and multiply by scale + if q_type_str_s == "fp8": + # FP8 tensors are already float8_e4m3fn, just cast and scale + return quantized_tensor.to(torch.float32) * scale + + unpacked_tensor = quantized_tensor if q_type_str_s.endswith("int4"): unpacked_tensor = unpack_int4(quantized_tensor) @@ -121,10 +133,20 @@ def quantize_tensor_with_scale(tensor_float, scale, quant_type, q_type_str): if quant_type == "NONE": return tensor_float + q_type_str_s = str(q_type_str) + + # FP8 quantization: scale and cast to float8_e4m3fn (no rounding needed) + if q_type_str_s == "fp8": + # FP8 E4M3 has max representable value of 448.0 + # Scale the tensor and clamp to FP8 range, then cast + scaled = tensor_float / scale + clamped = torch.clamp(scaled, -448.0, 448.0) + return clamped.to(torch.float8_e4m3fn) + + # INT8/INT4 quantization: scale, round, clamp to integer range qmin, qmax = get_q_range(q_type_str) quantized = torch.clamp(torch.round(tensor_float / scale), qmin, qmax) - q_type_str_s = str(q_type_str) if q_type_str_s.endswith("int4"): quantized = pack_int4(quantized.to(torch.int8)) else: @@ -318,10 +340,17 @@ def __init__( # Quantization parameters self.k_quant_type = k_quant_type self.v_quant_type = v_quant_type - self.kv_cache_type = kv_cache_type - # Determine bit width from cache type if applicable - self.kv_cache_bit_width = 4 if kv_cache_type == "int4" 
else (8 if kv_cache_type == "int8" else 0) self.share_kv_scale = share_kv_scale + # Determine bit width from cache type if applicable + if kv_cache_type == "int4": + self.kv_cache_bit_width = 4 + elif kv_cache_type == "int8": + self.kv_cache_bit_width = 8 + elif kv_cache_type == "fp8": + self.kv_cache_bit_width = 8 # FP8 is 8 bits + else: + self.kv_cache_bit_width = 0 + self.kv_cache_type = kv_cache_type def shape_dict(self): shapes = super().shape_dict() @@ -450,6 +479,8 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): cache_type = TensorProto.UINT8 elif config.kv_cache_type == "int8": cache_type = TensorProto.INT8 + elif config.kv_cache_type == "fp8": + cache_type = TensorProto.FLOAT8E4M3FN # Compute actual cache shapes (packed for INT4) past_key_shape = list(shape_dict["past_key"]) diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 5cbba989a4dbd..6def1be804743 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -57,7 +57,10 @@ has_int4_kv_cache = ", int4-kv-cache=" in get_build_info() -enable_debug_print = False +has_fp8_kv_cache = ", fp8-kv-cache=" in get_build_info() + +# Enable debug print if tensor or node dumping is enabled in build. +enable_debug_print = ("dump-tensor" in get_build_info()) or ("dump-node" in get_build_info()) enable_deterministic_check = True # ################################################################################################# @@ -674,7 +677,8 @@ def gqa_past_func( k_scale = k_scale.to(torch.float32) k_scale = k_scale.contiguous() bind_tensor(io_binding, "k_scale", k_scale, device, k_scale_ort_type) - if v_scale is not None: + + if v_scale is not None and not config.share_kv_scale: v_scale_ort_type = TensorProto.FLOAT if v_scale.dtype != torch.float32: v_scale = v_scale.to(torch.float32) @@ -931,20 +935,30 @@ def parity_check_gqa_prompt( elif causal: window_size = (-1, 0) - # --- PyTorch Reference Path --- - if config.kv_cache_bit_width == 4 or config.kv_cache_type == "int8": + if config.kv_cache_bit_width == 4 or config.kv_cache_type == "int8" or config.kv_cache_type == "fp8": + # k/v are already quantized (int8/fp8) in inputs k_ref_dequant = dequantize_tensor(k, k_scale, config.k_quant_type, config.kv_cache_type) v_ref_dequant = dequantize_tensor(v, v_scale, config.v_quant_type, config.kv_cache_type) else: k_ref_dequant = dequantize_tensor( - quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type), - k_scale, + quantize_tensor_with_scale( + k, + k_scale.to(torch.float32) if k_scale is not None else None, + config.k_quant_type, + config.kv_cache_type, + ), + k_scale.to(torch.float32) if k_scale is not None else None, config.k_quant_type, config.kv_cache_type, ) v_ref_dequant = dequantize_tensor( - quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type), - v_scale, + quantize_tensor_with_scale( + v, + v_scale.to(torch.float32) if v_scale is not None else None, + config.v_quant_type, + config.kv_cache_type, + ), + v_scale.to(torch.float32) if v_scale is not None else None, config.v_quant_type, config.kv_cache_type, ) @@ -1097,6 +1111,9 @@ def parity_check_gqa_prompt( elif config.kv_cache_type == "int8": # For int8, present_k is int8 data present_k_torch = torch.from_numpy(present_k.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + # For fp8, present_k is float8_e4m3fn data, returned as uint8/int8 by ORT python + 
present_k_torch = torch.from_numpy(present_k).view(torch.float8_e4m3fn).to(device) else: present_k_torch = torch.from_numpy(present_k).to(device) @@ -1134,6 +1151,8 @@ def parity_check_gqa_prompt( present_v_torch = torch.from_numpy(present_v).to(device) elif config.kv_cache_type == "int8": present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_v_torch = torch.from_numpy(present_v).view(torch.float8_e4m3fn).to(device) else: present_v_torch = torch.from_numpy(present_v).to(device) @@ -1345,8 +1364,8 @@ def parity_check_gqa_past( # Quantize k and v for ORT when using quantized KV cache k_ort = k v_ort = v - if config.kv_cache_type in ["int8", "int4"]: - # NOTE: Quantize returns tensor with kv_cache_type (int8) + if config.kv_cache_type in ["int8", "int4", "fp8"]: + # NOTE: Quantize returns tensor with kv_cache_type (int8, int4, or fp8) k_ort = quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type) v_ort = quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type) @@ -1386,26 +1405,37 @@ def parity_check_gqa_past( if numpy.count_nonzero(out_ref_np) > 0 and numpy.count_nonzero(out_np) == 0: raise RuntimeError("Output is all zeros") + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + # --- Comparison --- - if config.k_quant_type == "NONE" and config.v_quant_type == "NONE": + compare_kv = (config.k_quant_type == "NONE" and config.v_quant_type == "NONE") or (config.kv_cache_type == "fp8") + if compare_kv: # Compare KV cache # Transpose reference back to BNSH to match ORT output k_cache_ref_np = k_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() v_cache_ref_np = v_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() - present_k_np = present_k.to(torch.float32).detach().cpu().numpy() - present_v_np = present_v.to(torch.float32).detach().cpu().numpy() - if not config.share_buffer: - total_len = config.past_kv_sequence_length + config.q_sequence_length - k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] - v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + if isinstance(present_k, torch.Tensor): + present_k_torch = present_k.to(device) + present_v_torch = present_v.to(device) + else: + present_k_torch = torch.from_numpy(present_k).to(device) + present_v_torch = torch.from_numpy(present_v).to(device) + + if config.kv_cache_type == "fp8": + # FP8 cache needs dequantization for comparison with float reference + present_k_dequant = dequantize_tensor(present_k_torch, k_scale, config.k_quant_type, config.kv_cache_type) + present_v_dequant = dequantize_tensor(present_v_torch, v_scale, config.v_quant_type, config.kv_cache_type) + present_k_np = present_k_dequant.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v_dequant.to(torch.float32).detach().cpu().numpy() + else: + present_k_np = present_k_torch.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v_torch.to(torch.float32).detach().cpu().numpy() numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) - print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") - numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) - # Compare quantized cache with proper masking per batch if config.k_quant_type != "NONE": if isinstance(present_k, torch.Tensor): @@ 
-1415,6 +1445,8 @@ def parity_check_gqa_past( present_k_torch = torch.from_numpy(present_k).to(device) elif config.kv_cache_type == "int8": present_k_torch = torch.from_numpy(present_k.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_k_torch = torch.from_numpy(present_k).view(torch.float8_e4m3fn).to(device) else: present_k_torch = torch.from_numpy(present_k).to(device) @@ -1455,6 +1487,8 @@ def parity_check_gqa_past( present_v_torch = torch.from_numpy(present_v).to(device) elif config.kv_cache_type == "int8": present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_v_torch = torch.from_numpy(present_v).view(torch.float8_e4m3fn).to(device) else: present_v_torch = torch.from_numpy(present_v).to(device) @@ -1851,8 +1885,14 @@ def gqa_cuda_quantized_test_cases(is_past: bool): else gqa_cuda_prompt_test_cases(allow_local=True) ) + kv_types = ["int8"] + if has_int4_kv_cache: + kv_types.append("int4") + if has_fp8_kv_cache: + kv_types.append("fp8") + for name, config in base_cases: - for kv_type in ["int8", "int4"] if has_int4_kv_cache else ["int8"]: + for kv_type in kv_types: for quant_mode in ["PER_TENSOR", "PER_CHANNEL"]: share_scales_options = [False] if quant_mode == "PER_TENSOR" and kv_type == "int8": @@ -1871,6 +1911,8 @@ def gqa_cuda_quantized_test_cases(is_past: bool): q_config.kv_cache_bit_width = 4 elif kv_type == "int8": q_config.kv_cache_bit_width = 8 + elif kv_type == "fp8": + q_config.kv_cache_bit_width = 8 q_name = f"{name}_quant_{kv_type}_{quant_mode}" if share_scales: @@ -1902,8 +1944,26 @@ def has_flash_attention(bf16=False): return True -rtol = {"fp16": 5e-3, "bf16": 5e-2, "int8_fp16": 5e-2, "int4_fp16": 5e-2, "int8_bf16": 5e-2, "int4_bf16": 5e-2} -atol = {"fp16": 5e-3, "bf16": 1e-2, "int8_fp16": 1e-1, "int4_fp16": 1e-1, "int8_bf16": 2e-1, "int4_bf16": 2e-1} +rtol = { + "fp16": 5e-3, + "bf16": 5e-2, + "int8_fp16": 5e-2, + "int4_fp16": 5e-2, + "int8_bf16": 5e-2, + "int4_bf16": 5e-2, + "fp8_fp16": 5e-2, + "fp8_bf16": 5e-2, +} +atol = { + "fp16": 5e-3, + "bf16": 1e-2, + "int8_fp16": 1e-1, + "int4_fp16": 1e-1, + "int8_bf16": 2e-1, + "int4_bf16": 2e-1, + "fp8_fp16": 1e-1, + "fp8_bf16": 2e-1, +} def has_quantized_kv_cache(): @@ -2355,6 +2415,134 @@ def test_gqa_int8_large_seq_batch4(self): atol=5e-2, ) + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_kv_cache(self): + """ + Test GQA with FP8 E4M3 quantized KV cache. + Requires SM89+ (Ada Lovelace or newer) and USE_FP8_KV_CACHE build flag. 
+ """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=128, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + try: + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + except Exception as e: + # FP8 may not be built, skip if kernel not registered + if "Float8E4M3FN" in str(e) or "fp8" in str(e).lower(): + self.skipTest(f"FP8 KV cache not available: {e}") + raise + + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_prompt(self): + """ + Test GQA Prompt phase with FP8 E4M3 quantized KV cache. + """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=128, + q_sequence_length=128, + kv_sequence_length=128, + past_kv_sequence_length=0, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + kv_cache_bit_width=8, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + try: + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + except Exception as e: + if "Float8E4M3FN" in str(e) or "fp8" in str(e).lower(): + self.skipTest(f"FP8 KV cache not available: {e}") + raise + + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_fallback_unsupported_head_size(self): + """ + Test GQA with FP8 KV cache on a head size not supported by XQA. + This forces fallback to the generic generic kernel (if available) or ensures graceful failure/correctness. 
+ """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=48, # Valid head size (multiple of 16) but not supported by XQA (supports 64, 128, 256) + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=64, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index a015ce6979f91..150f8418d75ab 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -371,7 +371,9 @@ def parity_check_mha( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) + out_ref, _ = attention_ref( + q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False + ) out_ref = out_ref.detach().cpu().numpy() numpy.testing.assert_allclose(