forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmlaKernels.cu
More file actions
1120 lines (998 loc) · 53.9 KB
/
mlaKernels.cu
File metadata and controls
1120 lines (998 loc) · 53.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
// Stateful prefix functor for cub::BlockScan. It carries the running total from one block-wide scan to the next,
// so a sequence of scans over consecutive tiles behaves like a single continuous prefix sum.
struct BlockPrefixCallbackOp
{
    // Sum of every block aggregate observed so far, i.e. the prefix that seeds the next tile.
    int mRunningTotal;

    // Starts the running total at `runningTotal`.
    __device__ BlockPrefixCallbackOp(int runningTotal)
        : mRunningTotal{runningTotal}
    {
    }

    // Called by the scan with the current tile's aggregate. The value returned by thread 0 seeds the tile; the
    // aggregate is then folded into the running total for the following tile.
    __device__ int operator()(int blockAggregate)
    {
        int const tilePrefix = mRunningTotal;
        mRunningTotal = tilePrefix + blockAggregate;
        return tilePrefix;
    }
};
// Type map used by the kernels below:
//   Type        - the 16-byte vector used for one vectorized load/store of T elements
//                 (the primary template is a pass-through).
//   GPTJEltType - the packed two-element type that the GPT-J rotary helpers operate on.
template <typename T>
struct VecType
{
    typedef T Type;
    typedef T GPTJEltType;
};

// fp32: four floats per 16-byte vector, handled as float2 pairs.
template <>
struct VecType<float>
{
    typedef float4 Type;
    typedef float2 GPTJEltType;
};

// fp16: eight halves carried in a uint4, handled as 32-bit (half2-sized) pairs.
template <>
struct VecType<half>
{
    typedef uint4 Type;
    typedef uint32_t GPTJEltType;
};

// bf16: eight values in mmha::bf16_8_t, handled as __nv_bfloat162 pairs.
template <>
struct VecType<__nv_bfloat16>
{
    typedef mmha::bf16_8_t Type;
    typedef __nv_bfloat162 GPTJEltType;
};

// Sixteen FP8 (e4m3) values packed into a single 16-byte-aligned vector.
struct __align__(16) fp8_16_t
{
    __nv_fp8x4_e4m3 x, y, z, w;
};

// fp8 e4m3: sixteen values per fp8_16_t, handled as __nv_fp8x2_e4m3 pairs.
template <>
struct VecType<__nv_fp8_e4m3>
{
    typedef fp8_16_t Type;
    typedef __nv_fp8x2_e4m3 GPTJEltType;
};
// Compile-time tiling parameters for copying MLA entries out of the paged KV cache.
// One cached token holds kLoraSize compressed-KV elements followed by kRopeSize k_pe elements (kHeadSize total).
// Each thread moves exactly one 16-byte vector (VecT) of the token, so kThreadPerHead threads cover one token and a
// thread block covers kTokenPerBlock tokens.
template <typename T>
struct loadPagedKVKernelTraits
{
    static constexpr int kLoraSize = 512;
    static constexpr int kRopeSize = 64;
    static constexpr int kHeadSize = kLoraSize + kRopeSize;
    using VecT = typename VecType<T>::Type;
    static constexpr int kBytesPerElem = sizeof(T);
    static constexpr int kBytesPerLoad = 16;
    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
    // Every thread performs a single VecT load/store, so the vector type must be exactly one load wide.
    static_assert(sizeof(VecT) == kBytesPerLoad, "VecType<T>::Type must be kBytesPerLoad (16Bytes) wide");
    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
    // Consumers (e.g. loadPagedKVCacheForMLAKernel) route each vector to compressed KV or k_pe by comparing its index
    // with kKVThreadPerHead. A vector straddling the lora/rope boundary would therefore be misrouted.
    static_assert((kLoraSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kLoraSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
    static constexpr int kVecPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
    // For each head, kThreadPerHead threads fetch all the kv cache data; each thread reads the kv cache only once.
    static constexpr int kThreadPerHead = kVecPerHead;
    // Tokens fetched per thread block: 4 for fp32, 8 for every other type.
    static constexpr int kTokenPerBlock = std::is_same_v<T, float> ? 4 : 8;
    static constexpr int kBlockSize = kThreadPerHead * kTokenPerBlock;
    // Number of leading vectors in a token that belong to the compressed-KV (lora) part.
    static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
};
// Scales NUM source values by `scale_val`, converts them to FP8 e4m3 and stores them to `dst_global_ptr`.
// Sources are consumed in pairs: 16-bit types through their packed two-element type (TypeConverter<SrcType>::Type),
// 32-bit types as float2. The FP8 output is written in 8-byte (float2) chunks for 16-bit sources and in 4-byte
// (float) chunks otherwise, so `dst_global_ptr` must be aligned for that chunk type.
template <typename SrcType, int NUM>
inline __device__ void quantCopy(
    __nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
{
    constexpr bool kHalfWidthSrc = sizeof(SrcType) == 2;
    using StoreT = typename std::conditional<kHalfWidthSrc, float2, float>::type;
    using PairT = typename std::conditional<kHalfWidthSrc, typename TypeConverter<SrcType>::Type, float2>::type;
    constexpr int kStoreBytes = sizeof(StoreT);
    constexpr int kTotalBytes = NUM * sizeof(__nv_fp8_e4m3);
    static_assert(kTotalBytes % kStoreBytes == 0);
    static_assert(kStoreBytes % (sizeof(__nv_fp8_e4m3) * 2) == 0);
    constexpr int kStoreCount = kTotalBytes / kStoreBytes;
    constexpr int kPairsPerStore = kStoreBytes / sizeof(__nv_fp8_e4m3) / 2;

    PairT const* srcPairs = reinterpret_cast<PairT const*>(src_fragment_ptr);
    StoreT* dstChunks = reinterpret_cast<StoreT*>(dst_global_ptr);
#pragma unroll
    for (int chunk = 0; chunk < kStoreCount; ++chunk)
    {
        StoreT packed;
        __nv_fp8x2_e4m3* packedPairs = reinterpret_cast<__nv_fp8x2_e4m3*>(&packed);
#pragma unroll
        for (int pair = 0; pair < kPairsPerStore; ++pair)
        {
            float2 scaled = cuda_cast<float2>(srcPairs[chunk * kPairsPerStore + pair]);
            scaled.x *= scale_val;
            scaled.y *= scale_val;
            packedPairs[pair] = __nv_fp8x2_e4m3(scaled);
        }
        dstChunks[chunk] = packed;
    }
}
// Reads NUM FP8 e4m3 values from `src_fragment_ptr`, scales them by `scale_val` and stores them as DstType.
// The FP8 input is consumed in __nv_fp8x2_e4m3 pairs. Output is assembled in VecType<DstType>::Type chunks (16 bytes)
// and written chunk by chunk, so `dst_global_ptr` must be aligned for that vector type.
template <typename DstType, int NUM>
inline __device__ void dequantCopy(
    DstType* dst_global_ptr, __nv_fp8_e4m3 const* src_fragment_ptr, float const scale_val = 1.f)
{
    using StoreT = typename VecType<DstType>::Type;
    using PairT =
        typename std::conditional<sizeof(DstType) == 2, typename TypeConverter<DstType>::Type, float2>::type;
    constexpr int kStoreBytes = sizeof(StoreT);
    constexpr int kTotalBytes = NUM * sizeof(DstType);
    static_assert(kTotalBytes % kStoreBytes == 0);
    static_assert(kStoreBytes % (sizeof(DstType) * 2) == 0);
    constexpr int kStoreCount = kTotalBytes / kStoreBytes;
    constexpr int kPairsPerStore = kStoreBytes / sizeof(DstType) / 2;

    __nv_fp8x2_e4m3 const* srcPairs = reinterpret_cast<__nv_fp8x2_e4m3 const*>(src_fragment_ptr);
    StoreT* dstChunks = reinterpret_cast<StoreT*>(dst_global_ptr);
#pragma unroll
    for (int chunk = 0; chunk < kStoreCount; ++chunk)
    {
        StoreT packed;
        PairT* packedPairs = reinterpret_cast<PairT*>(&packed);
#pragma unroll
        for (int pair = 0; pair < kPairsPerStore; ++pair)
        {
            float2 scaled = cuda_cast<float2>(srcPairs[chunk * kPairsPerStore + pair]);
            scaled.x *= scale_val;
            scaled.y *= scale_val;
            packedPairs[pair] = cuda_cast<PairT>(scaled);
        }
        dstChunks[chunk] = packed;
    }
}
// Context-phase MLA kernel: applies GPT-J style RoPE to the rope parts of Q and K and writes K into the KV cache.
//
// Launch layout (as indexed below): blockIdx.x tiles the tokens of one sequence, blockIdx.y is the batch index and
// blockIdx.z selects the work item:
//   * blockIdx.z <  head_num: rotate the q rope slice of head `blockIdx.z` (from `q_ptr`, or from `q_pe` in
//     absorption mode) together with the k_pe slice read from `fuse_buf`. Head 0 also writes the rotated k_pe into
//     the KV cache. Rotated q goes to `q_ptr`; rotated k goes to `k_ptr` outside absorption mode.
//   * blockIdx.z >= head_num: the remaining gridDim.z - head_num blocks cooperatively copy the first K_DIM
//     (compressed KV) elements of every `fuse_buf` row into the KV cache.
// `fuse_buf` rows hold c_k + ROPE_DIM elements per token.
// NOTE(review): the cache layout uses K_DIM while the source row stride uses c_k; presumably K_DIM == c_k, so confirm
// at the launch site.
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* q_pe, T* k_ptr, T const* fuse_buf,
    KVCacheBuffer kv_cache, int q_pe_ld, int q_pe_stride, float2 const* cos_sin_cache, size_t head_num, int head_size,
    int c_k, int* cu_q_seqlens, int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type,
    float const* quant_scale_kv, int32_t const* helix_position_offsets, bool absorption_mode)
{
    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;

    // Block/Head idx.
    size_t const batch_idx = blockIdx.y;
    size_t const head_idx = blockIdx.z;

    // Per-sequence values depend only on blockIdx.y, so they are uniform across the block: read them once.
    int const global_token_offset = cu_q_seqlens[batch_idx];
    int const cache_seq_len = kv_cache_lengths[batch_idx];
    // An empty sequence has nothing to rotate or cache. Without this guard, the clamp to `cache_seq_len - 1` below
    // yields -1 and the unconditional Q/K loads read the row before this sequence. The condition is block-uniform,
    // so returning early cannot strand any thread at the __syncwarp below.
    if (cache_seq_len <= 0)
    {
        return;
    }
    float const quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.f;

    // The nope head_size for q.
    // Use the latent_space head size in the absorption mode.
    int const nope_head_size_q = absorption_mode ? c_k : head_size;
    if (head_idx < head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        // Mainloop.
        for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {
            bool const valid_token = local_token_idx < cache_seq_len;
            // Out-of-range threads still perform the loads below (all threads in the block stay involved), so they
            // operate on the last valid token. Clamp a COPY of the index: clamping the loop variable itself pulls it
            // back on every iteration and the loop never terminates once it needs more than one pass.
            int const token_idx_in_kv_cache = std::min(local_token_idx, cache_seq_len - 1);
            int const global_token_idx = token_idx_in_kv_cache + global_token_offset;
            auto const position_id
                = helix_position_offsets ? helix_position_offsets[global_token_idx] : token_idx_in_kv_cache;
            float2 const* rotary_coef_cache_buffer
                = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
            VecT q, k;
            auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
            auto src_q_global_offset = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
                + (head_size + ROPE_DIM) * head_idx + head_size;
            // In the absorption mode, we load pe from q_pe instead of q_ptr.
            T* q_pe_input = q_ptr;
            if (absorption_mode)
            {
                q_pe_input = q_pe;
                src_q_global_offset = static_cast<size_t>(global_token_idx) * q_pe_stride + q_pe_ld * head_idx;
            }
            q = *reinterpret_cast<VecT const*>(&q_pe_input[src_q_global_offset + head_dim_idx]);
            k = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
            // Pack two elements into one for gptj rotary embedding.
#pragma unroll
            for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
            {
                GPTJEltT& q_ = reinterpret_cast<GPTJEltT*>(&q)[elt_id];
                GPTJEltT& k_ = reinterpret_cast<GPTJEltT*>(&k)[elt_id];
                float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                mmha::apply_rotary_embedding_gptj(q_, k_, rotary_coef_cache);
            }
            // do sync
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == 0)
                {
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(
                        token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            reinterpret_cast<T const*>(&k), quant_scale_kv_val);
                    }
                    else
                    {
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx] = k;
                    }
                }
                auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (nope_head_size_q + ROPE_DIM)
                    + head_idx * (nope_head_size_q + ROPE_DIM) + nope_head_size_q + head_dim_idx;
                auto const dst_k_idx = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
                    + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
                reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = q;
                // Only write k_pe to k_buf in the non-absorption mode.
                if (!absorption_mode)
                {
                    reinterpret_cast<VecT*>(k_ptr)[dst_k_idx / ELTS_PER_VEC] = k;
                }
            }
        }
    }
    else
    {
        int const block_dim = gridDim.z - head_num;
        int const block_id = head_idx - head_num;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        // Mainloop.
        for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {
            // Only in-range tokens touch memory in this branch, so no clamping is needed and the loop index is left
            // untouched.
            if (local_token_idx < cache_seq_len)
            {
                int const global_token_idx = local_token_idx + global_token_offset;
                auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);
                auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, local_token_idx));
                auto inBlockIdx = kv_cache.getKVLocalIdx(local_token_idx, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
                if (cache_type == KvCacheDataType::FP8)
                {
                    quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                        fuse_buf + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
                }
                else
                {
                    reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                        = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
                }
            }
        }
    }
}
// Generation-phase MLA kernel.
//
// Launch layout (as indexed below): blockIdx.x tiles all generated tokens (total_s_len = batch * seq_len) and
// blockIdx.y selects the work item:
//   * blockIdx.y <  head_num            : RoPE on the q rope slice of that head, read from `q_pe`; the result goes to
//                                          `quant_q` (FP8 cache) or `qkv_output`.
//   * blockIdx.y == head_num            : RoPE on k_pe (from `fuse_buf`), appended to the KV cache.
//   * head_num + 1 ... head_num + 8     : copy the compressed-KV part of `fuse_buf` into the KV cache and fill
//                                          seqQOffset[1..batch].
//   * remaining blocks (FP8 cache only) : quantize the nope part of Q from `qkv_output` into `quant_q`.
// Block (0, 0) additionally resets fmha_tile_counter[0] / seqQOffset[0], computes the FP8 bmm scales, and writes the
// exclusive prefix sum of kv_cache_lengths into seqKVOffsets[0..batch].
// On SM90+ the kernel uses griddepcontrol (programmatic dependent launch).
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, T const* fuse_buf, void* quant_q,
    KVCacheBuffer kv_cache, float2 const* cos_sin_cache, size_t head_num, int c_k, int total_s_len, int seq_len,
    int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
    int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
    float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
    float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
    bool const* helix_is_inactive_rank)
{
    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VEC_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;
    // Number of y-blocks dedicated to copying the compressed KV into the cache.
    constexpr int KV_COPY_BLOCKS = 8;

    // Block/Head idx.
    size_t const head_idx = blockIdx.y;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif
    if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
    {
        fmha_tile_counter[0] = 0;
        seqQOffset[0] = 0;
        // Calculate bmm scale for FP8 MLA
        if (cache_type == KvCacheDataType::FP8)
        {
            float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
            float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
            float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
            if (bmm1_scale)
            {
                // The scale prepared for log2 optimization.
                constexpr float kLog2e = 1.4426950408889634074f;
                // The scale after fmha bmm1.
                float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
                bmm1_scale[0] = bmm1_scale_val;
                bmm1_scale[1] = bmm1_scale_val * kLog2e;
            }
            if (bmm2_scale)
            {
                // The scale after fmha bmm2.
                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
            }
        }
    }
    if (head_idx <= head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        int const seq_len_loop_end = size_t((total_s_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        float const quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;
        float const quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;
        // Mainloop.
        for (int global_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool const valid_token = global_token_idx < total_s_len;
            VecT data;
            if (valid_token)
            {
                auto const position_id
                    = (helix_position_offsets != nullptr ? helix_position_offsets[global_token_idx]
                                                         : kv_cache_lengths[batch_idx] - seq_len + local_token_idx);
                float2 const* rotary_coef_cache_buffer
                    = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
                if (head_idx == head_num)
                {
                    auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
                    data = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
                }
                else
                {
                    auto const src_q_global_offset
                        = static_cast<size_t>(global_token_idx) * q_pe_stride + q_pe_ld * head_idx;
                    data = *reinterpret_cast<VecT const*>(&q_pe[src_q_global_offset + head_dim_idx]);
                }
                // Pack two elements into one for gptj rotary embedding.
#pragma unroll
                for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
                {
                    GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];
                    float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                    data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
                }
            }
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == head_num)
                {
                    // Mirror the compressed-KV branch below: inactive helix ranks must not append to the KV cache,
                    // otherwise only the k_pe half of the cache entry would be written.
                    bool const write_kv_cache
                        = helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx];
                    if (write_kv_cache)
                    {
                        auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                        auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                        auto inBlockIdx = kv_cache.getKVLocalIdx(
                            token_kv_idx, 0, TOTAL_VEC_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                        if (cache_type == KvCacheDataType::FP8)
                        {
                            quantCopy<T, ELTS_PER_VEC>(
                                reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                                reinterpret_cast<T const*>(&data), quant_scale_kv_val);
                        }
                        else
                        {
                            reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
                        }
                    }
                }
                else
                {
                    auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (c_k + ROPE_DIM)
                        + head_idx * (c_k + ROPE_DIM) + c_k + head_dim_idx;
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + dst_q_idx,
                            reinterpret_cast<T const*>(&data), quant_scale_q_val);
                    }
                    else
                    {
                        reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = data;
                    }
                }
            }
        }
    }
    else if (head_idx <= head_num + KV_COPY_BLOCKS)
    {
        // Exactly KV_COPY_BLOCKS y-blocks take this branch. When the grid also carries FP8 Q-quantization blocks
        // (last branch), gridDim.y - head_num - 1 over-counts the participants, and the stride would skip tokens
        // whenever the loop needs more than one pass. So clamp the stride factor to the real block count.
        int const block_dim = std::min(KV_COPY_BLOCKS, static_cast<int>(gridDim.y - head_num - 1));
        int const block_id = head_idx - head_num - 1;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        float quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;
        // Mainloop.
        for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool valid_token = global_token_idx < total_s_len;
            if (!valid_token)
            {
                continue;
            }
            // The Q offset of a batch depends only on its index, so publish it even for helix-inactive batches.
            // Previously those entries were left unwritten. NOTE(review): presumably consumed by the downstream
            // attention kernel; confirm.
            if (head_dim_vec_idx == 0)
            {
                seqQOffset[batch_idx + 1] = head_num * seq_len * (batch_idx + 1);
            }
            // Inactive helix ranks do not append this token to their KV cache.
            if (helix_is_inactive_rank != nullptr && helix_is_inactive_rank[batch_idx])
            {
                continue;
            }
            auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
            auto const src_kv_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);
            auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
            auto inBlockIdx = kv_cache.getKVLocalIdx(token_kv_idx, 0, TOTAL_VEC_PER_HEAD, head_dim_vec_idx);
            if (cache_type == KvCacheDataType::FP8)
            {
                quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                    fuse_buf + src_kv_global_offset + head_dim_idx, quant_scale_kv_val);
            }
            else
            {
                reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                    = *reinterpret_cast<VecT const*>(&fuse_buf[src_kv_global_offset + head_dim_idx]);
            }
        }
    }
    else
    {
        if (cache_type == KvCacheDataType::FP8)
        {
            int block_dim = gridDim.y - head_num - 1 - KV_COPY_BLOCKS;
            int block_id = head_idx - head_num - 1 - KV_COPY_BLOCKS;
            size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
            size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
            size_t const head_num_idx = (block_id % head_num) * (K_HEAD_SIZE + HEAD_SIZE);
            size_t const seq_len_loop_end
                = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
            float quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;
            // Mainloop.
            for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD)
                     + (block_id / head_num) * gridDim.x * K_TOKENS_PER_BLOCK + blockIdx.x * K_TOKENS_PER_BLOCK;
                 global_token_idx < seq_len_loop_end;
                 global_token_idx += (block_dim / head_num) * gridDim.x * K_TOKENS_PER_BLOCK)
            {
                if (global_token_idx < total_s_len)
                {
                    size_t const load_idx
                        = global_token_idx * head_num * (K_HEAD_SIZE + HEAD_SIZE) + head_num_idx + head_dim_idx;
                    quantCopy<T, ELTS_PER_VEC>(
                        reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + load_idx, qkv_output + load_idx, quant_scale_q_val);
                }
            }
        }
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
    // The implementation of the parallel scan in the thread block (see CUB for details).
    using BlockScan = cub::BlockScan<int, BLOCK_SIZE>;
    // Allocate storage in shared memory to do the scan.
    __shared__ typename BlockScan::TempStorage tempKVStorage;
    BlockPrefixCallbackOp prefixKVOp(0);
    if (blockIdx.x == 0 && blockIdx.y == 0)
    {
        int const batchSizeBound = total_s_len / seq_len;
        for (int batchOffset = 0; batchOffset <= batchSizeBound; batchOffset += BLOCK_SIZE)
        {
            // The index of the batch.
            int batchIdx = batchOffset + threadIdx.x;
            int seqKVLength = 0;
            if (batchIdx < batchSizeBound)
            {
                seqKVLength = kv_cache_lengths[batchIdx];
            }
            int seqKVOffset;
            BlockScan(tempKVStorage).ExclusiveSum(seqKVLength, seqKVOffset, prefixKVOp);
            if (batchIdx <= batchSizeBound)
            {
                seqKVOffsets[batchIdx] = seqKVOffset;
            }
            // CUB requires a block barrier before TempStorage is reused by the next iteration's scan. The loop
            // bounds are block-uniform, so every thread reaches this barrier.
            __syncthreads();
        }
    }
}
// Copies cached MLA entries out of the paged KV cache into contiguous buffers, dequantizing when TCache is FP8.
//
// Launch layout (as indexed below): blockIdx.y is the batch index and blockIdx.x tiles that sequence's cached
// tokens. Each token is handled by KT::kThreadPerHead threads, each moving one 16-byte vector:
//   * vectors below KT::kKVThreadPerHead go to `compressed_kv_ptr` ({total_cached_tokens, kLoraSize});
//   * the remaining vectors go to `k_pe_ptr` ({total_cached_tokens, kRopeSize}).
// Per-sequence token ranges come from the prefix sums in `cu_ctx_cached_kv_lens` (batch + 1 entries).
template <typename T, typename TCache>
__global__ void loadPagedKVCacheForMLAKernel(T* compressed_kv_ptr, T* k_pe_ptr,
    tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_cached_kv_lens, int max_input_seq_len,
    float const* kv_scale_quant_orig_ptr)
{
    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
        "TCache must be either the same type as T or __nv_fp8_e4m3");
    using KT = typename tensorrt_llm::kernels::loadPagedKVKernelTraits<TCache>;
    int const batch_idx = static_cast<int>(blockIdx.y);
    float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
    bool const is_valid_kv = head_dim_vec_idx < KT::kKVThreadPerHead;
    size_t const seq_len_loop_end
        = (max_input_seq_len + KT::kTokenPerBlock - 1) / KT::kTokenPerBlock * KT::kTokenPerBlock;
    int64_t const global_token_offset = cu_ctx_cached_kv_lens[batch_idx];
    int64_t const cache_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];
    for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kTokenPerBlock;
         local_token_idx < seq_len_loop_end; local_token_idx += KT::kTokenPerBlock * gridDim.x)
    {
        int const token_idx_in_kv_cache = local_token_idx;
        bool const valid_token = token_idx_in_kv_cache < cache_kv_len;
        if (valid_token)
        {
            auto* kvSrc = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
            // head_idx == 0
            auto kvBlockIdx
                = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast<int>(head_dim_vec_idx));
            auto src_data = reinterpret_cast<typename KT::VecT*>(kvSrc)[kvBlockIdx];
            // 64-bit on purpose. The offsets in cu_ctx_cached_kv_lens are int64_t, and total_cached_tokens *
            // kLoraSize exceeds INT_MAX beyond ~4.2M cached tokens, which 32-bit index math would overflow.
            int64_t const global_token_idx = global_token_offset + local_token_idx;
            if (is_valid_kv)
            {
                // compressed_kv {total_token, lora_size}
                int64_t const dstIdx = global_token_idx * KT::kLoraSize + static_cast<int64_t>(head_dim_idx);
                // copy back to compressed_kv
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(compressed_kv_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(compressed_kv_ptr + dstIdx,
                        reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
            else
            {
                // k_pe {total_token, rope_size}; here head_dim_idx >= kLoraSize, so the offset is non-negative.
                int64_t const dstIdx
                    = global_token_idx * KT::kRopeSize + (static_cast<int64_t>(head_dim_idx) - KT::kLoraSize);
                // copy back to k_pe
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(k_pe_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(
                        k_pe_ptr + dstIdx, reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
        }
    }
}
// q {total_uncached_tokens, h, d_nope + d_rope}
// latent_cache {total_uncached_tokens, d_k + d_rope}
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T* latent_cache_ptr,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr)
{
static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
"TCache must be either the same type as T or __nv_fp8_e4m3");
// Constants.
using VecT = typename VecType<T>::Type;
using GPTJEltT = typename VecType<T>::GPTJEltType;
constexpr auto HEAD_SIZE = ROPE_DIM;
constexpr auto K_HEAD_SIZE = K_DIM;
constexpr auto BYTES_PER_ELT = sizeof(T);
constexpr auto BYTES_PER_LOAD = 16;
constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;
// Block/Head idx.
size_t const batch_idx = blockIdx.y;
size_t const head_idx = blockIdx.z;
int64_t const global_token_offset = cu_seq_lens[batch_idx] - cu_ctx_cached_kv_lens[batch_idx];
int64_t const cached_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];
int64_t const uncached_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx] - cached_kv_len;
if (head_idx <= head_num)
{
size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
size_t const seq_len_loop_end
= size_t((max_input_uncached_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;
// Mainloop.
for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
{
int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
bool valid_token = local_token_idx < uncached_kv_len;
int const global_token_idx = local_token_idx + global_token_offset;
VecT data;
if (valid_token)
{
auto const position_id = token_idx_in_kv_cache;
float2 const* rotary_coef_cache_buffer
= cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
if (head_idx == head_num)
{
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
data = *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
}
else
{
auto const src_q_global_offset
= static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
+ (nope_size + ROPE_DIM) * head_idx + nope_size;
data = *reinterpret_cast<VecT const*>(&q_ptr[src_q_global_offset + head_dim_idx]);
}
// Pack two elements into one for gptj rotary embedding.
#pragma unroll
for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
{
GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];
float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
}
}
// do sync
__syncwarp();
if (valid_token)
{
if (head_idx == head_num)
{
auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
auto inBlockIdx = kv_cache.getKVLocalIdx(
token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
if constexpr (std::is_same_v<TCache, T>)
{
reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
}
else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
reinterpret_cast<T const*>(&data), quant_scale_kv_val);
}
// copy to latent_cache (for chunked prefill, it will not load kv cache for uncached k_pe)
// we only need to copy original value.
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
*reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
}
else
{
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
+ head_idx * (nope_size + ROPE_DIM) + nope_size + head_dim_idx;
reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = data;
}
}
}
}
else
{
int block_dim = gridDim.z - head_num - 1;
int block_id = head_idx - head_num - 1;
size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
size_t const seq_len_loop_end
= size_t((max_input_uncached_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;
// Mainloop.
for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
+ blockIdx.x * K_TOKENS_PER_BLOCK;
local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
{
int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
bool valid_token = local_token_idx < uncached_kv_len;
int const global_token_idx = local_token_idx + global_token_offset;
if (valid_token)
{
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM);
auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
auto inBlockIdx
= kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
if constexpr (std::is_same_v<TCache, T>)
{
reinterpret_cast<VecT*>(kDst)[inBlockIdx]
= *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
}
else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
latent_cache_ptr + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
}
}
}
}
}
template <typename T, int BLOCK_SIZE, int QK_NOPE_HEAD_DIM, int QK_ROPE_HEAD_DIM, int V_HEAD_DIM, bool ABSORPTION_MODE>
// Quantizes the context-phase Q (and, outside ABSORPTION_MODE, K and V) from T to FP8 E4M3 using
// quant_scale_qkv_ptr[0] (1 if null), and derives the FMHA BMM scales.
//
// Launch layout as read by this kernel:
//   - blockIdx.z selects the head; gridDim.z is taken as the number of heads.
//   - blockIdx.x / gridDim.x stride over tokens.
//   - Each thread handles one 16-byte vector of a head row (assumes blockDim.x == BLOCK_SIZE — TODO confirm at
//     launch sites).
//
// Layouts implied by the indexing below:
//   q_buf, quant_q_buf : [total_q_len, head_num, QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM], contiguous.
//   k_buf, quant_k_buf : [total_kv_len, head_num, QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM], contiguous.
//   v_buf              : token stride (QK_NOPE_HEAD_DIM + V_HEAD_DIM) * head_num, head stride V_HEAD_DIM.
//   quant_v_buf        : [total_kv_len, head_num, V_HEAD_DIM], contiguous.
// In ABSORPTION_MODE only Q is quantized; k_buf/v_buf/quant_k_buf/quant_v_buf are never accessed.
//
// A single thread (block (0,0,0), thread 0) writes the scales:
//   bmm1_scale[0] = dequant_scale_q * dequant_scale_kv * host_bmm1_scale, bmm1_scale[1] = bmm1_scale[0] * log2(e);
//   bmm2_scale[0] = quant_scale_o * dequant_scale_kv.
// Null output pointers are skipped; null input scale pointers default to 1.
__global__ void quantizeCopyInputToFp8Kernel(T const* q_buf, __nv_fp8_e4m3* quant_q_buf, T const* k_buf,
    __nv_fp8_e4m3* quant_k_buf, T const* v_buf, __nv_fp8_e4m3* quant_v_buf, int total_q_len, int total_kv_len,
    float const* quant_scale_qkv_ptr, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
    float const* dequant_scale_q, float const* dequant_scale_kv, float host_bmm1_scale)
{
    // Constants.
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    constexpr auto QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
    static_assert(
        (QK_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "QK head size needs to be multiple of 16 bytes.");
    static_assert((V_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "V head size needs to be multiple of 16 bytes.");
    constexpr auto QK_VECS_PER_HEAD = QK_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto V_VECS_PER_HEAD = V_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % QK_VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    static_assert(ABSORPTION_MODE || (BLOCK_SIZE % V_VECS_PER_HEAD) == 0,
        "Kernel block should be able to handle entire heads in non-absorption mode.");
    constexpr auto QK_TOKENS_PER_BLOCK = BLOCK_SIZE / QK_VECS_PER_HEAD;
    size_t const head_idx = blockIdx.z;
    size_t const head_num = gridDim.z;
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
    {
        // Calculate bmm scale for FP8 MLA
        float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
        float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
        float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
        if (bmm1_scale)
        {
            // The scale prepared for log2 optimization.
            constexpr float kLog2e = 1.4426950408889634074f;
            // The scale after fmha bmm1.
            float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
            bmm1_scale[0] = bmm1_scale_val;
            bmm1_scale[1] = bmm1_scale_val * kLog2e;
        }
        if (bmm2_scale)
        {
            // The scale after fmha bmm2.
            bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
        }
    }
    float const quant_scale_qkv_val = quant_scale_qkv_ptr ? quant_scale_qkv_ptr[0] : 1.f;

    // Position of this thread's 16-byte vector inside a Q/K head row.
    size_t const qk_head_dim_vec_idx = (threadIdx.x % QK_VECS_PER_HEAD);
    size_t const qk_head_dim_idx = qk_head_dim_vec_idx * ELTS_PER_VEC;
    // Loop bounds are rounded up to a whole number of per-block token groups, so every thread in a block runs the
    // same number of iterations. Out-of-range tokens are skipped by the explicit bound check.
    size_t const q_len_loop_end
        = size_t((total_q_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;

    // Quantize Q, both src and dst are contiguous
    for (int q_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
         q_token_idx < q_len_loop_end; q_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
    {
        if (q_token_idx < total_q_len)
        {
            auto const src_q_idx
                = static_cast<size_t>(q_token_idx) * QK_HEAD_DIM * head_num + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
            auto const dst_q_idx = src_q_idx;
            quantCopy<T, ELTS_PER_VEC>(quant_q_buf + dst_q_idx, &q_buf[src_q_idx], quant_scale_qkv_val);
        }
    }

    // Only quantize K and V in non-absorption mode. All K/V-only state lives in this branch, so the absorption-mode
    // instantiation carries no dead locals (and never evaluates BLOCK_SIZE / V_VECS_PER_HEAD, which the
    // static_assert above does not require to divide evenly in that mode).
    if constexpr (!ABSORPTION_MODE)
    {
        // Quantize K, both src and dst are contiguous
        size_t const k_len_loop_end
            = size_t((total_kv_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;
        for (int k_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
             k_token_idx < k_len_loop_end; k_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
        {
            if (k_token_idx < total_kv_len)
            {
                auto const src_k_idx = static_cast<size_t>(k_token_idx) * QK_HEAD_DIM * head_num
                    + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
                auto const dst_k_idx = src_k_idx;
                quantCopy<T, ELTS_PER_VEC>(quant_k_buf + dst_k_idx, &k_buf[src_k_idx], quant_scale_qkv_val);
            }
        }

        // Quantize V, dst V is contiguous, but src V is not contiguous, so we need to calculate the stride
        constexpr auto V_TOKENS_PER_BLOCK = BLOCK_SIZE / V_VECS_PER_HEAD;
        size_t const v_head_dim_vec_idx = (threadIdx.x % V_VECS_PER_HEAD);
        size_t const v_head_dim_idx = v_head_dim_vec_idx * ELTS_PER_VEC;
        size_t const v_len_loop_end
            = size_t((total_kv_len + V_TOKENS_PER_BLOCK - 1) / V_TOKENS_PER_BLOCK) * V_TOKENS_PER_BLOCK;
        size_t const src_v_token_stride = (QK_NOPE_HEAD_DIM + V_HEAD_DIM) * head_num;
        for (int v_token_idx = (threadIdx.x / V_VECS_PER_HEAD) + blockIdx.x * V_TOKENS_PER_BLOCK;
             v_token_idx < v_len_loop_end; v_token_idx += V_TOKENS_PER_BLOCK * gridDim.x)
        {
            if (v_token_idx < total_kv_len)
            {
                auto const src_v_idx
                    = static_cast<size_t>(v_token_idx) * src_v_token_stride + head_idx * V_HEAD_DIM + v_head_dim_idx;
                auto const dst_v_idx
                    = static_cast<size_t>(v_token_idx) * V_HEAD_DIM * head_num + head_idx * V_HEAD_DIM + v_head_dim_idx;
                quantCopy<T, ELTS_PER_VEC>(quant_v_buf + dst_v_idx, &v_buf[src_v_idx], quant_scale_qkv_val);
            }
        }
    }
}
template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
    // Enqueues applyMLARopeAndAssignQKVKernelOptContext on `stream`. The launch is asynchronous and this
    // function performs no error check after it.
    //
    // Grid layout:
    //   x: max_input_seq_len in chunks of kTokensPerGridX tokens,
    //   y: one slice per batch entry,
    //   z: head_num + kExtraHeadSlices slices.
    // NOTE(review): the extra z slices beyond head_num are presumably consumed by the kernel for the shared
    // RoPE-K head and the latent-cache copy; confirm against the kernel.
    constexpr int kThreadsPerBlock = 256;
    constexpr int kTokensPerGridX = 32;
    constexpr int kExtraHeadSlices = 8;
    int const grid_x = int(tensorrt_llm::common::divUp(params.max_input_seq_len, kTokensPerGridX));
    dim3 const grid(grid_x, params.batch_size, params.head_num + kExtraHeadSlices);

    // The template arguments 512 and 64 are presumably the latent (kv_lora_rank) and RoPE head dims that the
    // kernel is specialized for. TODO confirm they match params.meta at the call sites.
    auto const nope_head_dim = params.meta.qk_nope_head_dim;
    applyMLARopeAndAssignQKVKernelOptContext<T, kThreadsPerBlock, 512, 64, KVCacheBuffer>
        <<<grid, kThreadsPerBlock, 0, stream>>>(params.q_buf, params.q_pe, params.k_buf, params.latent_cache,
            kv_cache_buffer, params.q_pe_ld, params.q_pe_stride, params.cos_sin_cache, params.head_num,
            nope_head_dim, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens,
            params.max_input_seq_len, params.cache_type, params.quant_scale_kv, params.helix_position_offsets,
            params.absorption_mode);
}
template <typename T>
void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(params.cache_type == KvCacheDataType::FP8, "MLA Context: cache_type must be FP8");
TLLM_CHECK_WITH_INFO(params.q_buf != nullptr, "MLA Context: q_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.absorption_mode || params.k_buf != nullptr,
"MLA Context: k_buf must be non-null in non-absorption mode");
TLLM_CHECK_WITH_INFO(params.absorption_mode || params.v_buf != nullptr,
"MLA Context: v_buf must be non-null in non-absorption mode");
TLLM_CHECK_WITH_INFO(params.quant_q_buf != nullptr, "MLA Context: quant_q_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.absorption_mode || params.quant_k_buf != nullptr,
"MLA Context: quant_k_buf must be non-null in non-absorption mode");
TLLM_CHECK_WITH_INFO(params.absorption_mode || params.quant_v_buf != nullptr,
"MLA Context: quant_v_buf must be non-null in non-absorption mode");
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing separate qkv to FP8");
if (params.acc_q_len > 0)
{
// The Q tensor has layout of [num_tokens, head_num, 576] in the absorption mode.
// Convert Q to FP8 in absorption mode.
if (params.absorption_mode)
{
constexpr int threads_per_block = 288;
constexpr int num_tokens_per_block = threads_per_block * 16 / 576 * sizeof(T);
dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len, num_tokens_per_block)), 1, params.head_num);
TLLM_LOG_DEBUG(
"Launching quantizeCopyInputToFp8Kernel with grid_size: (%d, %d, %d), threads_per_block: %d, "
"total_kv_len: %d, acc_q_len: %d, absorption_mode: %d",
grid.x, grid.y, grid.z, threads_per_block, total_kv_len, params.acc_q_len, params.absorption_mode);
quantizeCopyInputToFp8Kernel<T, threads_per_block, 512, 64, 512, true>
<<<grid, threads_per_block, 0, stream>>>(params.q_buf, static_cast<__nv_fp8_e4m3*>(params.quant_q_buf),
params.k_buf, static_cast<__nv_fp8_e4m3*>(params.quant_k_buf), params.v_buf,
static_cast<__nv_fp8_e4m3*>(params.quant_v_buf), params.acc_q_len, total_kv_len,
params.quant_scale_qkv, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
}
else
{
// The Q or K tensor has layout of [num_tokens, head_num, 192] in the non-absorption mode.
// The V tensor has layout of [num_tokens, head_num, 128] in the non-absorption mode.
// Convert Q, K, V to FP8 in non-absorption mode.