cmake/CMakeLists.txt (15 additions, 0 deletions)
@@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
+option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON)
option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA)
message( STATUS "Enable int4 kv cache for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
endif()

+if (onnxruntime_USE_FP8_KV_CACHE)
+message( STATUS "Enable fp8 kv cache for CUDA EP")
+list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1)
+endif()
endif()

if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))
@@ -1442,6 +1448,15 @@ if (Git_FOUND)
if (onnxruntime_USE_INT4_KV_CACHE)
string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
endif()
+if (onnxruntime_USE_FP8_KV_CACHE)
+string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
+endif()
+if (onnxruntime_DUMP_TENSOR)
+string(APPEND ORT_BUILD_INFO "dump-tensor=1, ")
+endif()
+if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
+string(APPEND ORT_BUILD_INFO "dump-node=1, ")
+endif()
endif()
string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
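The CMake additions above also stamp the new flags into ORT_BUILD_INFO, so a given binary can be checked for fp8-kv-cache support at run time. A minimal sketch, assuming the C API's GetBuildInfoString entry point (not part of this diff) and that it surfaces the ORT_BUILD_INFO string configured here:

```cpp
// Sketch: inspect the build-info string for the "fp8-kv-cache=1" marker appended above.
// Assumes onnxruntime_cxx_api.h and the C API's GetBuildInfoString(); both are outside this diff.
#include <cstring>
#include <iostream>
#include <onnxruntime_cxx_api.h>

int main() {
  const char* info = Ort::GetApi().GetBuildInfoString();
  const bool has_fp8_kv_cache = std::strstr(info, "fp8-kv-cache=1") != nullptr;
  std::cout << "build info: " << info << "\n"
            << "fp8 kv cache kernels compiled in: " << (has_fp8_kv_cache ? "yes" : "no") << "\n";
  return 0;
}
```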
docs/OperatorKernels.md (1 addition, 1 deletion)
@@ -1003,7 +1003,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
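With tensor(float8e4m3fn) added to T_CACHE, the k_scale/v_scale inputs remain float tensors and quantization stays per-tensor (the kernel changes below additionally require k_scale and v_scale to be the same tensor). A rough host-side sketch of what per-tensor E4M3 quantization means; purely illustrative: the helper names are hypothetical, the scale choice (amax / 448) is one common convention rather than something mandated by the op, and the real kernels store Float8E4M3FN elements on the GPU.

```cpp
// Illustrative per-tensor FP8 (E4M3) quantization of a KV buffer.
// E4M3's largest finite magnitude is 448, so one shared scale covers the whole tensor;
// dequantization is approximately x ~= q * scale.
#include <algorithm>
#include <cmath>
#include <vector>

// One possible scale choice: absolute maximum mapped to the E4M3 limit.
float ComputePerTensorScale(const std::vector<float>& kv) {
  float amax = 0.0f;
  for (float x : kv) amax = std::max(amax, std::fabs(x));
  return amax > 0.0f ? amax / 448.0f : 1.0f;
}

// Divide by the shared scale and clamp to the representable range. The bit-level
// E4M3 encoding (rounding to the 8-bit format) is left to hardware/intrinsics.
std::vector<float> QuantizeToE4M3Domain(const std::vector<float>& kv, float scale) {
  std::vector<float> q(kv.size());
  for (size_t i = 0; i < kv.size(); ++i) {
    q[i] = std::clamp(kv[i] / scale, -448.0f, 448.0f);
  }
  return q;
}
```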
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc (22 additions, 4 deletions)
@@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_FP8_KV_CACHE
+REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
+REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
+#endif
#ifdef USE_INT4_KV_CACHE
REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
@@ -292,6 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
parameters.past_present_share_buffer = (data.past_key == data.present_key);

bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
+constexpr bool is_int8 = std::is_same<U, int8_t>::value;
+constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;

// Allocate XQA scratch if needed (only for Flash Decoding path)
IAllocatorUniquePtr<void> xqa_scratch_buffer;
@@ -315,18 +321,30 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
parameters.local_window_size == -1) {
int group_size = parameters.num_heads / parameters.kv_num_heads;

-bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+bool is_int8_quantized_supported = is_int8 &&
+(k_quant_type_ == KVQuantizationType::PER_TENSOR &&
v_quant_type_ == KVQuantizationType::PER_TENSOR &&
data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor.
parameters.kv_cache_bit_width == 8 &&
(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
(group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));

+#ifdef USE_FP8_KV_CACHE
+bool is_fp8_quantized_supported = is_fp8 &&
+(k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+data.k_scale == data.v_scale &&
+(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+(group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) &&
+(device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9))); // FP8 requires SM89+ (Ada Lovelace)
+#else
+constexpr bool is_fp8_quantized_supported = false;
+#endif

bool is_non_quantized_supported = !is_inputs_quantized &&
(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
(64 % group_size == 0);

-data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);

if (data.use_xqa) {
size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
@@ -336,7 +354,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
parameters.kv_num_heads,
parameters.head_size,
parameters.seqlen_present_kv_cache,
-parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone,
std::is_same<T, BFloat16>::value);
assert(xqa_internal_bytes > 0);
// Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
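Two parts of the new dispatch are easy to miss in the diff: the FP8 path is gated on SM 8.9 (Ada Lovelace) or SM 9.0+ devices, and the XQA quant type passed to GetXQAScratchSize is now derived from the cache element type. A standalone sketch of both decisions, using stand-in declarations for the ORT-internal XqaQuantType and Float8E4M3FN types:

```cpp
// Standalone sketch of the two new dispatch decisions; the enum and fp8 struct below
// are placeholders for the ORT-internal definitions, not the real ones.
#include <cstdint>
#include <type_traits>

enum class XqaQuantType { kNone, kInt8, kFp8 };
struct Float8E4M3FN { uint8_t value; };  // stand-in for ORT's fp8 cache element type

// FP8 KV cache requires SM 8.9 (Ada Lovelace) or SM 9.0+ (Hopper and newer),
// mirroring the device_prop.major/minor check added above.
constexpr bool SupportsFp8KvCache(int sm_major, int sm_minor) {
  return sm_major >= 9 || (sm_major == 8 && sm_minor == 9);
}

// Map the cache element type U to an XQA quant mode, roughly how the updated
// GetXQAScratchSize call site picks kFp8 vs kInt8 once quantization is enabled.
template <typename U>
constexpr XqaQuantType CacheQuantType() {
  if (std::is_same<U, Float8E4M3FN>::value) return XqaQuantType::kFp8;
  if (std::is_same<U, int8_t>::value) return XqaQuantType::kInt8;
  return XqaQuantType::kNone;
}

static_assert(SupportsFp8KvCache(8, 9), "Ada (SM 8.9) qualifies");
static_assert(!SupportsFp8KvCache(8, 6), "Ampere (SM 8.6) does not");
static_assert(CacheQuantType<Float8E4M3FN>() == XqaQuantType::kFp8, "fp8 cache maps to kFp8");
static_assert(CacheQuantType<int8_t>() == XqaQuantType::kInt8, "int8 cache maps to kInt8");
```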