
Commit 8836990

[TRTLLM-3602][feat] support nvfp4 model and fp8 kv cache for MLA chunked prefill (Blackwell) (NVIDIA#5475)
Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
1 parent 8dfa31c commit 8836990

12 files changed, +326 −150 lines changed


cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu

Lines changed: 101 additions & 12 deletions
@@ -16,7 +16,9 @@
 #include "mlaChunkedPrefill.cuh"
 #include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
 #include "tensorrt_llm/common/mathUtils.h"
+#include <cuda_fp8.h>
 #include <cutlass/array.h>
 #include <cutlass/half.h>
2224

@@ -89,6 +91,70 @@ struct setChunkedKVKernelTraits
     static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
 };
 
+template <typename SrcType, int NUM>
+inline __device__ void quantCopy(
+    __nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
+{
+    using DstVecType = typename std::conditional<sizeof(SrcType) == 2, float2, float>::type;
+    using SrcType2 = typename std::conditional<sizeof(SrcType) == 2,
+        typename tensorrt_llm::common::TypeConverter<SrcType>::Type, float2>::type;
+    static constexpr int COPY_SIZE = sizeof(DstVecType);
+    static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(__nv_fp8_e4m3);
+    static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE;
+    static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0);
+    static constexpr int CVT_NUM = COPY_SIZE / sizeof(__nv_fp8_e4m3) / 2;
+    static_assert(COPY_SIZE % (sizeof(__nv_fp8_e4m3) * 2) == 0);
+    DstVecType fragment;
+    int offset = 0;
+#pragma unroll
+    for (int i = 0; i < LOOP_NUM; ++i)
+    {
+#pragma unroll
+        for (int j = 0; j < CVT_NUM; ++j)
+        {
+            float2 val2 = tensorrt_llm::common::cuda_cast<float2>(
+                reinterpret_cast<SrcType2 const*>(src_fragment_ptr)[j + offset]);
+            val2.x *= scale_val;
+            val2.y *= scale_val;
+            reinterpret_cast<__nv_fp8x2_e4m3*>(&fragment)[j] = __nv_fp8x2_e4m3(val2);
+        }
+        reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment;
+        offset += CVT_NUM;
+    }
+}
+
+template <typename DstType, int NUM>
+inline __device__ void dequantCopy(
+    DstType* dst_global_ptr, __nv_fp8_e4m3 const* src_fragment_ptr, float const scale_val = 1.f)
+{
+    using DstVecType = uint4;
+    using DstType2
+        = std::conditional_t<sizeof(DstType) == 2, typename tensorrt_llm::common::TypeConverter<DstType>::Type, float2>;
+    static constexpr int COPY_SIZE = sizeof(DstVecType);
+    static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(DstType);
+    static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE;
+    static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0);
+    static constexpr int CVT_NUM = COPY_SIZE / sizeof(DstType) / 2;
+    static_assert(COPY_SIZE % (sizeof(DstType) * 2) == 0);
+    DstVecType fragment;
+    int offset = 0;
+#pragma unroll
+    for (int i = 0; i < LOOP_NUM; ++i)
+    {
+#pragma unroll
+        for (int j = 0; j < CVT_NUM; ++j)
+        {
+            float2 val2 = tensorrt_llm::common::cuda_cast<float2>(
+                reinterpret_cast<__nv_fp8x2_e4m3 const*>(src_fragment_ptr)[j + offset]);
+            val2.x *= scale_val;
+            val2.y *= scale_val;
+            reinterpret_cast<DstType2*>(&fragment)[j] = tensorrt_llm::common::cuda_cast<DstType2>(val2);
+        }
+        reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment;
+        offset += CVT_NUM;
+    }
+}
+
 // merged_attn [q_total_len, H=128, D=128] (T)
 // merged_softmax_sum [q_total_len, H, 2] (float, max/sum)
 template <typename T>
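For orientation, here is a minimal standalone sketch (not part of the commit; the toy_* names are illustrative) of the per-pair conversion that quantCopy/dequantCopy vectorize: scale a float2 into the e4m3 range, pack it with the same __nv_fp8x2_e4m3 constructor the kernel uses, and widen it back while applying the inverse scale.

// Sketch only: scalar analogue of the vectorized quantCopy/dequantCopy helpers above.
// Assumes a toolchain that ships <cuda_fp8.h>; the widening path here goes through half2,
// whereas the kernels above use TRT-LLM's cuda_cast<float2>.
#include <cuda_fp8.h>
#include <cuda_fp16.h>

__device__ __nv_fp8x2_e4m3 toy_quant_pair(float2 v, float scale_orig_to_quant)
{
    v.x *= scale_orig_to_quant; // scale into the e4m3 representable range
    v.y *= scale_orig_to_quant;
    return __nv_fp8x2_e4m3(v);  // same packing constructor used in quantCopy
}

__device__ float2 toy_dequant_pair(__nv_fp8x2_e4m3 q, float scale_quant_to_orig)
{
    float2 v = __half22float2(static_cast<__half2>(q)); // widen the packed e4m3 pair
    v.x *= scale_quant_to_orig; // plays the role of kv_scale_quant_orig above
    v.y *= scale_quant_to_orig;
    return v;
}

// Tiny round-trip kernel: quantize with scale s, dequantize with 1/s.
__global__ void toy_fp8_roundtrip(float2 const* in, float2* out, float s, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = toy_dequant_pair(toy_quant_pair(in[i], s), 1.f / s);
    }
}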
@@ -179,12 +245,15 @@ __global__ void mergeAttnWithSoftmaxKernel(T* merged_attn, float2* merged_softma
 // kv_output {total_chunk_token=b*chunk_size, h=1, d_lora}
 // k_pe_output {total_chunk_token, h=1, d_rope}
-template <typename T>
+template <typename T, typename TCache>
 __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_ptr,
     tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_chunked_len, int chunked_size,
-    int chunked_idx)
+    int chunked_idx, float const* kv_scale_quant_orig_ptr)
 {
-    using KT = loadChunkedKVKernelTraits<T>;
+    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
+        "TCache must be either the same type as T or __nv_fp8_e4m3");
+    using KT = loadChunkedKVKernelTraits<TCache>;
+    float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
     int const batch_idx = static_cast<int>(blockIdx.y);
     [[maybe_unused]] int const head_idx = static_cast<int>(blockIdx.z); // default 0
@@ -215,14 +284,31 @@ __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_
             // kv_output {total_chunk_token, h=1, d}
             int const global_st_idx
                 = global_st_offset * KT::kLoraSize + local_token_idx * KT::kLoraSize + head_dim_idx;
-            *reinterpret_cast<typename KT::VecT*>(output_kv_ptr + global_st_idx) = ld_data;
+            if constexpr (std::is_same_v<TCache, T>)
+            {
+                *reinterpret_cast<typename KT::VecT*>(output_kv_ptr + global_st_idx) = ld_data;
+            }
+            else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
+            {
+                dequantCopy<T, KT::kElemPerLoad>(output_kv_ptr + global_st_idx,
+                    reinterpret_cast<__nv_fp8_e4m3 const*>(&ld_data), kv_scale_quant_orig);
+            }
         }
         else
         {
             // k_pe_output {total_chunk_token, h=1, d_rope}
             int const global_st_idx = global_st_offset * KT::kRopeSize + local_token_idx * KT::kRopeSize
                 + (head_dim_idx - KT::kLoraSize);
-            *reinterpret_cast<typename KT::VecT*>(output_k_pe_ptr + global_st_idx) = ld_data;
+
+            if constexpr (std::is_same_v<TCache, T>)
+            {
+                *reinterpret_cast<typename KT::VecT*>(output_k_pe_ptr + global_st_idx) = ld_data;
+            }
+            else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
+            {
+                dequantCopy<T, KT::kElemPerLoad>(output_k_pe_ptr + global_st_idx,
+                    reinterpret_cast<__nv_fp8_e4m3 const*>(&ld_data), kv_scale_quant_orig);
+            }
         }
     }
 }
@@ -329,19 +415,19 @@ void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T c
 }
 
 // load single chunk kv from kv_cache for each request
-template <typename T>
+template <typename T, typename TCache>
 void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
     int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
-    cudaStream_t stream)
+    float const* kv_scale_quant_orig_ptr, cudaStream_t stream)
 {
-    using KT = loadChunkedKVKernelTraits<T>;
+    using KT = loadChunkedKVKernelTraits<TCache>;
     TLLM_CHECK_WITH_INFO(lora_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
     TLLM_CHECK_WITH_INFO(lora_size == KT::kLoraSize, "lora dim should be equal to %d", KT::kLoraSize);
     TLLM_CHECK_WITH_INFO(rope_size == KT::kRopeSize, "rope dim should be equal to %d", KT::kRopeSize);
     // {chunked_unit_size / token_per_block, batch_size, head_num}
     dim3 grid(static_cast<int>(tensorrt_llm::common::divUp(chunked_size, KT::kTokenPerBlock)), num_contexts, 1);
-    loadChunkedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(
-        output_kv_ptr, output_k_pe_ptr, kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx);
+    loadChunkedKVCacheForMLAKernel<T, TCache><<<grid, KT::kBlockSize, 0, stream>>>(output_kv_ptr, output_k_pe_ptr,
+        kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx, kv_scale_quant_orig_ptr);
 }
 
 // output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
@@ -369,9 +455,12 @@ void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const b
         float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, \
         int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, \
         int const head_size, cudaStream_t stream); \
-    template void invokeMLALoadChunkedKV<T>(T * output_kv_ptr, T * output_k_pe_ptr, KVBlockArray const& kv_cache, \
+    template void invokeMLALoadChunkedKV<T, T>(T * output_kv_ptr, T * output_k_pe_ptr, KVBlockArray const& kv_cache, \
         int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, \
-        int chunked_idx, cudaStream_t stream); \
+        int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
+    template void invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>(T * output_kv_ptr, T * output_k_pe_ptr, \
+        KVBlockArray const& kv_cache, int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, \
+        int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
     template void invokeMLASetChunkedKV<T>(T * output_kv, T const* kv, T const* k_pe, int const batch_size, \
         int const max_seq_len, int const num_heads, int uncompressed_head_size, int rope_size, \
         int64_t const* cu_seq_lens, int const kv_cache_tokens_per_block, cudaStream_t stream);
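A hedged host-side sketch of how the new launcher signature is meant to be called (the buffer setup, chunk sizes, and the 512/64 lora/rope split are illustrative values, not taken from this hunk): an FP8 KV cache selects the <T, __nv_fp8_e4m3> instantiation and passes the device pointer to the dequant scale; otherwise TCache matches T and the scale pointer may be nullptr, which the kernel treats as 1.0f.

// Sketch only: dispatching the chunked-KV load for a bf16 model; kv_cache, cu_chunk_lens,
// the output buffers, and the scale pointer are assumed to be prepared elsewhere.
void loadChunkExample(__nv_bfloat16* out_kv, __nv_bfloat16* out_k_pe,
    tensorrt_llm::kernels::KVBlockArray const& kv_cache, int num_contexts, int64_t const* cu_chunk_lens,
    bool fp8_kv_cache, float const* kv_scale_quant_orig, cudaStream_t stream)
{
    int const lora_size = 512, rope_size = 64;      // illustrative MLA head layout
    int const chunked_size = 8192, chunked_idx = 0; // illustrative chunk parameters
    if (fp8_kv_cache)
    {
        // FP8 cache: dequantize on load using the provided per-tensor scale.
        tensorrt_llm::kernels::invokeMLALoadChunkedKV<__nv_bfloat16, __nv_fp8_e4m3>(out_kv, out_k_pe, kv_cache,
            num_contexts, cu_chunk_lens, lora_size, rope_size, chunked_size, chunked_idx, kv_scale_quant_orig, stream);
    }
    else
    {
        // Same-type cache: no scaling needed, the scale pointer may stay null.
        tensorrt_llm::kernels::invokeMLALoadChunkedKV<__nv_bfloat16, __nv_bfloat16>(out_kv, out_k_pe, kv_cache,
            num_contexts, cu_chunk_lens, lora_size, rope_size, chunked_size, chunked_idx,
            /*kv_scale_quant_orig_ptr=*/nullptr, stream);
    }
}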

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh

Lines changed: 2 additions & 2 deletions
@@ -33,10 +33,10 @@ void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T c
     cudaStream_t stream);
 
 // load single chunk kv from kv_cache for each request
-template <typename T>
+template <typename T, typename TCache>
 void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
     int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
-    cudaStream_t stream);
+    float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
 
 // output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
 // zero

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 4 additions & 4 deletions
@@ -851,16 +851,16 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T*
             if constexpr (std::is_same_v<TCache, T>)
             {
                 reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
-                // copy to latent_cache (for chunked prefill, it will not load kv cache for uncached k_pe)
-                auto const src_k_global_offset
-                    = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
-                *reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
             }
             else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
             {
                 quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                     reinterpret_cast<T const*>(&data), quant_scale_kv_val);
             }
+            // copy to latent_cache (for chunked prefill, it will not load kv cache for uncached k_pe)
+            // we only need to copy original value.
+            auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
+            *reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
         }
         else
         {
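The hunk above moves the latent_cache write out of the if constexpr branches: the paged KV cache receives either the raw vector or a quantCopy'd FP8 vector, while latent_cache always keeps the original-precision k_pe so chunked prefill does not have to re-load (and dequantize) it from the cache. A minimal sketch of that write pattern, using illustrative names rather than the kernel's own locals and assuming the surrounding file's includes and quantCopy helper:

// Sketch only: mirrors the structure of the hunk above. `data` is a vector of T holding the
// rotated k_pe values; quant_scale_kv converts from the T range into the e4m3 range.
template <typename T, typename TCache, typename VecT, int ELTS_PER_VEC>
__device__ void writeRopeVector(void* kv_cache_dst, T* latent_cache_dst, VecT const& data, float quant_scale_kv)
{
    if constexpr (std::is_same_v<TCache, T>)
    {
        *reinterpret_cast<VecT*>(kv_cache_dst) = data; // cache stores T: plain vectorized store
    }
    else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        // cache stores FP8: scale and narrow on the way out (see quantCopy above)
        quantCopy<T, ELTS_PER_VEC>(
            reinterpret_cast<__nv_fp8_e4m3*>(kv_cache_dst), reinterpret_cast<T const*>(&data), quant_scale_kv);
    }
    // latent_cache keeps the unquantized value in both paths.
    *reinterpret_cast<VecT*>(latent_cache_dst) = data;
}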

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp

Lines changed: 37 additions & 10 deletions
@@ -48,17 +48,18 @@ void loadPagedKVCacheForMLAHelper(torch::Tensor& compressed_kv, torch::Tensor& k
         cu_ctx_cached_kv_lens_ptr, max_input_seq_len, lora_size, rope_size, kv_scale_quant_orig_ptr, stream);
 }
 
-template <typename T>
+template <typename T, typename TCache>
 void loadChunkedKVCacheForMLAHelper(torch::Tensor& output_kv, torch::Tensor& output_k_pe, KVBlockArray& kv_cache,
     int const num_contexts, torch::Tensor const& cu_ctx_chunked_len, int lora_size, int rope_size,
-    int const chunked_size, int const chunked_idx)
+    int const chunked_size, int const chunked_idx, float const* kv_scale_quant_orig_ptr)
 {
     auto stream = at::cuda::getCurrentCUDAStream(output_kv.get_device());
 
     T* output_kv_ptr = static_cast<T*>(output_kv.data_ptr());
     T* output_k_pe_ptr = static_cast<T*>(output_k_pe.data_ptr());
-    tensorrt_llm::kernels::invokeMLALoadChunkedKV<T>(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts,
-        cu_ctx_chunked_len.data_ptr<int64_t>(), lora_size, rope_size, chunked_size, chunked_idx, stream);
+    tensorrt_llm::kernels::invokeMLALoadChunkedKV<T, TCache>(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts,
+        cu_ctx_chunked_len.data_ptr<int64_t>(), lora_size, rope_size, chunked_size, chunked_idx,
+        kv_scale_quant_orig_ptr, stream);
 }
 
 template <typename T>
@@ -327,18 +328,44 @@ std::vector<torch::Tensor> loadChunkedKVCacheForMLA(torch::ScalarType out_dtype,
 
     if (out_dtype == torch::kFloat16)
     {
-        loadChunkedKVCacheForMLAHelper<half>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
-            cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index);
+        if (kv_cache_quant_mode.hasFp8KvCache())
+        {
+            loadChunkedKVCacheForMLAHelper<half, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
+                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
+        }
+        else
+        {
+            loadChunkedKVCacheForMLAHelper<half, half>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
+                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
+        }
     }
     else if (out_dtype == torch::kFloat32)
     {
-        loadChunkedKVCacheForMLAHelper<float>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
-            cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index);
+        if (kv_cache_quant_mode.hasFp8KvCache())
+        {
+            loadChunkedKVCacheForMLAHelper<float, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
+                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
+        }
+        else
+        {
+            loadChunkedKVCacheForMLAHelper<float, float>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
+                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
+        }
     }
     else if (out_dtype == torch::kBFloat16)
     {
-        loadChunkedKVCacheForMLAHelper<__nv_bfloat16>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
-            cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index);
+        if (kv_cache_quant_mode.hasFp8KvCache())
+        {
+            loadChunkedKVCacheForMLAHelper<__nv_bfloat16, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer,
+                num_contexts, cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index,
+                kv_scale_quant_orig_ptr);
+        }
+        else
+        {
+            loadChunkedKVCacheForMLAHelper<__nv_bfloat16, __nv_bfloat16>(outputs[0], outputs[1], kv_cache_buffer,
+                num_contexts, cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index,
+                kv_scale_quant_orig_ptr);
+        }
     }
 
     return outputs;
