16 | 16 |
17 | 17 | #include "mlaChunkedPrefill.cuh" |
18 | 18 | #include "tensorrt_llm/common/assert.h" |
| 19 | +#include "tensorrt_llm/common/cudaTypeUtils.cuh" |
19 | 20 | #include "tensorrt_llm/common/mathUtils.h" |
| 21 | +#include <cuda_fp8.h> |
20 | 22 | #include <cutlass/array.h> |
21 | 23 | #include <cutlass/half.h> |
22 | 24 |
@@ -89,6 +91,70 @@ struct setChunkedKVKernelTraits |
89 | 91 | static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock; |
90 | 92 | }; |
91 | 93 |
| 94 | +template <typename SrcType, int NUM> |
| 95 | +inline __device__ void quantCopy( |
| 96 | + __nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f) |
| 97 | +{ |
| 98 | + using DstVecType = typename std::conditional<sizeof(SrcType) == 2, float2, float>::type; |
| 99 | + using SrcType2 = typename std::conditional<sizeof(SrcType) == 2, |
| 100 | + typename tensorrt_llm::common::TypeConverter<SrcType>::Type, float2>::type; |
| 101 | + static constexpr int COPY_SIZE = sizeof(DstVecType); |
| 102 | + static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(__nv_fp8_e4m3); |
| 103 | + static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE; |
| 104 | + static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0); |
| 105 | + static constexpr int CVT_NUM = COPY_SIZE / sizeof(__nv_fp8_e4m3) / 2; |
| 106 | + static_assert(COPY_SIZE % (sizeof(__nv_fp8_e4m3) * 2) == 0); |
| 107 | + DstVecType fragment; |
| 108 | + int offset = 0; |
| 109 | +#pragma unroll |
| 110 | + for (int i = 0; i < LOOP_NUM; ++i) |
| 111 | + { |
| 112 | +#pragma unroll |
| 113 | + for (int j = 0; j < CVT_NUM; ++j) |
| 114 | + { |
| 115 | + float2 val2 = tensorrt_llm::common::cuda_cast<float2>( |
| 116 | + reinterpret_cast<SrcType2 const*>(src_fragment_ptr)[j + offset]); |
| 117 | + val2.x *= scale_val; |
| 118 | + val2.y *= scale_val; |
| 119 | + reinterpret_cast<__nv_fp8x2_e4m3*>(&fragment)[j] = __nv_fp8x2_e4m3(val2); |
| 120 | + } |
| 121 | + reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment; |
| 122 | + offset += CVT_NUM; |
| 123 | + } |
| 124 | +} |
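The new `quantCopy` helper packs `NUM` register-resident source elements (half, bfloat16, or float) into FP8 (e4m3), applying `scale_val`, and issues the widest global store whose byte width matches the FP8 output (8 bytes for a 16-bit source, 4 bytes for float). The core building block is the pairwise `float2 -> __nv_fp8x2_e4m3` conversion; below is a minimal standalone sketch of that pattern with a `half` source. Kernel name, launch shape, and the 4-byte alignment of `src + i` are assumptions for illustration, not code from this PR.

```cpp
// Minimal sketch (not the helper above): pairwise half -> FP8 quantization,
// illustrating the float2 -> __nv_fp8x2_e4m3 conversion quantCopy is built on.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void quantizePairs(__nv_fp8_e4m3* dst, __half const* src, int n, float scale)
{
    int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x); // one pair per thread
    if (i + 1 < n)
    {
        // Load two halves as one __half2, widen to float2, apply the quantization scale.
        float2 v = __half22float2(*reinterpret_cast<__half2 const*>(src + i));
        v.x *= scale;
        v.y *= scale;
        // Convert both lanes to e4m3 in one step and store the 2-byte pair.
        *reinterpret_cast<__nv_fp8x2_e4m3*>(dst + i) = __nv_fp8x2_e4m3(v);
    }
}
```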
| 125 | + |
| 126 | +template <typename DstType, int NUM> |
| 127 | +inline __device__ void dequantCopy( |
| 128 | + DstType* dst_global_ptr, __nv_fp8_e4m3 const* src_fragment_ptr, float const scale_val = 1.f) |
| 129 | +{ |
| 130 | + using DstVecType = uint4; |
| 131 | + using DstType2 |
| 132 | + = std::conditional_t<sizeof(DstType) == 2, typename tensorrt_llm::common::TypeConverter<DstType>::Type, float2>; |
| 133 | + static constexpr int COPY_SIZE = sizeof(DstVecType); |
| 134 | + static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(DstType); |
| 135 | + static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE; |
| 136 | + static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0); |
| 137 | + static constexpr int CVT_NUM = COPY_SIZE / sizeof(DstType) / 2; |
| 138 | + static_assert(COPY_SIZE % (sizeof(DstType) * 2) == 0); |
| 139 | + DstVecType fragment; |
| 140 | + int offset = 0; |
| 141 | +#pragma unroll |
| 142 | + for (int i = 0; i < LOOP_NUM; ++i) |
| 143 | + { |
| 144 | +#pragma unroll |
| 145 | + for (int j = 0; j < CVT_NUM; ++j) |
| 146 | + { |
| 147 | + float2 val2 = tensorrt_llm::common::cuda_cast<float2>( |
| 148 | + reinterpret_cast<__nv_fp8x2_e4m3 const*>(src_fragment_ptr)[j + offset]); |
| 149 | + val2.x *= scale_val; |
| 150 | + val2.y *= scale_val; |
| 151 | + reinterpret_cast<DstType2*>(&fragment)[j] = tensorrt_llm::common::cuda_cast<DstType2>(val2); |
| 152 | + } |
| 153 | + reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment; |
| 154 | + offset += CVT_NUM; |
| 155 | + } |
| 156 | +} |
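`dequantCopy` is the inverse path: it reads FP8 pairs, widens them to `float2` through `cuda_cast`, multiplies by the dequantization scale, and packs the results into a 16-byte `uint4` fragment so a single wide store reaches global memory. The per-pair step looks roughly like the hedged sketch below (half destination, standard `cuda_fp8.h`/`cuda_fp16.h` intrinsics; names and launch shape are illustrative only).

```cpp
// Minimal sketch of the pairwise FP8 -> half widening used by dequantCopy.
// Assumes (src + i) can be read as a 2-byte fp8x2 pair and (dst + i) stored as a __half2.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void dequantizePairs(__half* dst, __nv_fp8_e4m3 const* src, int n, float scale)
{
    int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x); // one pair per thread
    if (i + 1 < n)
    {
        // Reinterpret two e4m3 bytes as one fp8x2 pair, widen to half2, then float2.
        __nv_fp8x2_storage_t pair = *reinterpret_cast<__nv_fp8x2_storage_t const*>(src + i);
        float2 v = __half22float2(__half2(__nv_cvt_fp8x2_to_halfraw2(pair, __NV_E4M3)));
        // Apply the dequantization scale (kv_scale_quant_orig in the kernels above).
        v.x *= scale;
        v.y *= scale;
        *reinterpret_cast<__half2*>(dst + i) = __floats2half2_rn(v.x, v.y);
    }
}
```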
| 157 | + |
92 | 158 | // merged_attn [q_total_len, H=128, D=128] (T) |
93 | 159 | // merged_softmax_sum [q_total_len, H, 2] (float, max/sum) |
94 | 160 | template <typename T> |
@@ -179,12 +245,15 @@ __global__ void mergeAttnWithSoftmaxKernel(T* merged_attn, float2* merged_softma |
179 | 245 |
180 | 246 | // kv_output {total_chunk_token=b*chunk_size, h=1, d_lora} |
181 | 247 | // k_pe_output {total_chunk_token, h=1, d_rope} |
182 | | -template <typename T> |
| 248 | +template <typename T, typename TCache> |
183 | 249 | __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_ptr, |
184 | 250 | tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_chunked_len, int chunked_size, |
185 | | - int chunked_idx) |
| 251 | + int chunked_idx, float const* kv_scale_quant_orig_ptr) |
186 | 252 | { |
187 | | - using KT = loadChunkedKVKernelTraits<T>; |
| 253 | + static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>, |
| 254 | + "TCache must be either the same type as T or __nv_fp8_e4m3"); |
| 255 | + using KT = loadChunkedKVKernelTraits<TCache>; |
| 256 | + float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f; |
188 | 257 | int const batch_idx = static_cast<int>(blockIdx.y); |
189 | 258 | [[maybe_unused]] int const head_idx = static_cast<int>(blockIdx.z); // default 0 |
190 | 259 |
@@ -215,14 +284,31 @@ __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_ |
215 | 284 | // kv_output {total_chunk_token, h=1, d} |
216 | 285 | int const global_st_idx |
217 | 286 | = global_st_offset * KT::kLoraSize + local_token_idx * KT::kLoraSize + head_dim_idx; |
218 | | - *reinterpret_cast<typename KT::VecT*>(output_kv_ptr + global_st_idx) = ld_data; |
| 287 | + if constexpr (std::is_same_v<TCache, T>) |
| 288 | + { |
| 289 | + *reinterpret_cast<typename KT::VecT*>(output_kv_ptr + global_st_idx) = ld_data; |
| 290 | + } |
| 291 | + else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>) |
| 292 | + { |
| 293 | + dequantCopy<T, KT::kElemPerLoad>(output_kv_ptr + global_st_idx, |
| 294 | + reinterpret_cast<__nv_fp8_e4m3 const*>(&ld_data), kv_scale_quant_orig); |
| 295 | + } |
219 | 296 | } |
220 | 297 | else |
221 | 298 | { |
222 | 299 | // k_pe_output {total_chunk_token, h=1, d_rope} |
223 | 300 | int const global_st_idx = global_st_offset * KT::kRopeSize + local_token_idx * KT::kRopeSize |
224 | 301 | + (head_dim_idx - KT::kLoraSize); |
225 | | - *reinterpret_cast<typename KT::VecT*>(output_k_pe_ptr + global_st_idx) = ld_data; |
| 302 | + |
| 303 | + if constexpr (std::is_same_v<TCache, T>) |
| 304 | + { |
| 305 | + *reinterpret_cast<typename KT::VecT*>(output_k_pe_ptr + global_st_idx) = ld_data; |
| 306 | + } |
| 307 | + else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>) |
| 308 | + { |
| 309 | + dequantCopy<T, KT::kElemPerLoad>(output_k_pe_ptr + global_st_idx, |
| 310 | + reinterpret_cast<__nv_fp8_e4m3 const*>(&ld_data), kv_scale_quant_orig); |
| 311 | + } |
226 | 312 | } |
227 | 313 | } |
228 | 314 | } |
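Inside the kernel, the new `TCache` parameter is resolved entirely at compile time: when the cache already stores the compute dtype, the original vectorized store survives unchanged, and only the FP8 instantiation pays for the dequantization path. A condensed sketch of that `if constexpr` dispatch follows; the `storeFragment` helper is illustrative (it reuses `dequantCopy` from the hunk above and assumes a 16-byte fragment, as with the kernel's `KT::VecT`).

```cpp
// Sketch of the compile-time dispatch used for the two store paths.
// N is the number of TCache elements held in the register fragment `frag`.
#include <cuda_fp8.h>
#include <type_traits>

template <typename T, typename TCache, int N>
__device__ void storeFragment(T* dst, TCache const* frag, float dequant_scale)
{
    if constexpr (std::is_same_v<TCache, T>)
    {
        // Unquantized cache: forward the fragment with one wide 16-byte store.
        *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<uint4 const*>(frag);
    }
    else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        // FP8 cache: widen each pair back to T, applying the kv_scale_quant_orig factor.
        dequantCopy<T, N>(dst, frag, dequant_scale);
    }
}
```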
@@ -329,19 +415,19 @@ void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T c |
329 | 415 | } |
330 | 416 |
331 | 417 | // load single chunk kv from kv_cache for each request |
332 | | -template <typename T> |
| 418 | +template <typename T, typename TCache> |
333 | 419 | void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts, |
334 | 420 | int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx, |
335 | | - cudaStream_t stream) |
| 421 | + float const* kv_scale_quant_orig_ptr, cudaStream_t stream) |
336 | 422 | { |
337 | | - using KT = loadChunkedKVKernelTraits<T>; |
| 423 | + using KT = loadChunkedKVKernelTraits<TCache>; |
338 | 424 | TLLM_CHECK_WITH_INFO(lora_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize); |
339 | 425 | TLLM_CHECK_WITH_INFO(lora_size == KT::kLoraSize, "lora dim should be equal to %d", KT::kLoraSize); |
340 | 426 | TLLM_CHECK_WITH_INFO(rope_size == KT::kRopeSize, "rope dim should be equal to %d", KT::kRopeSize); |
341 | 427 | // {chunked_unit_size / token_per_block, batch_size, head_num} |
342 | 428 | dim3 grid(static_cast<int>(tensorrt_llm::common::divUp(chunked_size, KT::kTokenPerBlock)), num_contexts, 1); |
343 | | - loadChunkedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>( |
344 | | - output_kv_ptr, output_k_pe_ptr, kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx); |
| 429 | + loadChunkedKVCacheForMLAKernel<T, TCache><<<grid, KT::kBlockSize, 0, stream>>>(output_kv_ptr, output_k_pe_ptr, |
| 430 | + kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx, kv_scale_quant_orig_ptr); |
345 | 431 | } |
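A hypothetical call site for the updated launcher, selecting the `TCache` instantiation from the runtime KV-cache dtype. The `cache_type` check (e.g. via TensorRT-LLM's `KvCacheDataType` enum) and the surrounding variable names are assumptions about the caller, not part of this diff.

```cpp
// Hypothetical caller: pick the FP8 or pass-through instantiation at runtime.
if (cache_type == KvCacheDataType::FP8)
{
    // FP8 cache: the kernel dequantizes with kv_scale_quant_orig while loading.
    invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts,
        cu_ctx_chunked_len, lora_size, rope_size, chunked_size, chunked_idx, kv_scale_quant_orig_ptr, stream);
}
else
{
    // Cache already holds T: passing nullptr keeps the kernel's scale at its 1.0f default.
    invokeMLALoadChunkedKV<T, T>(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts, cu_ctx_chunked_len,
        lora_size, rope_size, chunked_size, chunked_idx, /*kv_scale_quant_orig_ptr=*/nullptr, stream);
}
```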
346 | 432 |
347 | 433 | // output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with |
@@ -369,9 +455,12 @@ void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const b |
369 | 455 | float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, \ |
370 | 456 | int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, \ |
371 | 457 | int const head_size, cudaStream_t stream); \ |
372 | | - template void invokeMLALoadChunkedKV<T>(T * output_kv_ptr, T * output_k_pe_ptr, KVBlockArray const& kv_cache, \ |
| 458 | + template void invokeMLALoadChunkedKV<T, T>(T * output_kv_ptr, T * output_k_pe_ptr, KVBlockArray const& kv_cache, \ |
373 | 459 | int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, \ |
374 | | - int chunked_idx, cudaStream_t stream); \ |
| 460 | + int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \ |
| 461 | + template void invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>(T * output_kv_ptr, T * output_k_pe_ptr, \ |
| 462 | + KVBlockArray const& kv_cache, int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, \ |
| 463 | + int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \ |
375 | 464 | template void invokeMLASetChunkedKV<T>(T * output_kv, T const* kv, T const* k_pe, int const batch_size, \ |
376 | 465 | int const max_seq_len, int const num_heads, int uncompressed_head_size, int rope_size, \ |
377 | 466 | int64_t const* cu_seq_lens, int const kv_cache_tokens_per_block, cudaStream_t stream); |
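Because the instantiation macro now emits both `invokeMLALoadChunkedKV<T, T>` and `invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>`, every compute dtype the macro is expanded for gets an FP8-cache variant automatically. The expansion presumably follows the usual pattern sketched below; the macro name is a placeholder, since its definition sits outside the lines shown in this hunk.

```cpp
// Placeholder macro name; the real macro is defined just above the instantiation list shown here.
INSTANTIATE_MLA_CHUNKED_PREFILL(float);
INSTANTIATE_MLA_CHUNKED_PREFILL(half);
#ifdef ENABLE_BF16
INSTANTIATE_MLA_CHUNKED_PREFILL(__nv_bfloat16);
#endif
```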