@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 
 #include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/vec_quant.cuh"
 
 #include <torch/torch.h>
@@ -125,38 +126,40 @@ DEVICE_INLINE void per_row_amax(fx4& dst, float* amax) {
   }
 }
 __global__ void nope_qkv_varseq_prefill_kernel(
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XQ, // [B_T][N_H][D_H]
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XK, // [B_T][N_KVH][D_H]
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XV, // [B_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_K, // [B][MAX_T][N_KVH][D_H] or
                  // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_V, // [B][MAX_T][N_KVH][D_H] or
                  // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XQ_O, // [B_T][N_H][D]
     int32_t* varseq_batch, // in decoding case we have T == 1 and so just pass
                            // nullptr
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> varseq_seqpos,
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        varseq_seqpos,
     int32_t* block_tables, // [B][MAX_PAGES], maps logical pages to physical
                            // ones for paged attention
     int32_t page_size,
     int32_t block_tables_b_stride,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         varseq_cache_seqpos,
-    int64_t* actual_batch_size =
-        nullptr, // When running in CUDA graph mode, the actual batch size
-                 // can be smaller than block_tables.size(0). In this case
-                 // rows of block_tables beyond actual_batch_size are not
-                 // initialized, and using them wil cause undefined
-                 // behavior. To prevent this, when actual_batch_size is
-                 // provided, the kernel exits if the current batch index is
-                 // larger of equal to actual_batch_size,
-    bool update_kv = true) {
+    int64_t*
+        actual_batch_size, // When running in CUDA graph mode, the actual batch
+                           // size can be smaller than block_tables.size(0). In
+                           // this case rows of block_tables beyond
+                           // actual_batch_size are not initialized, and using
+                           // them will cause undefined behavior. To prevent
+                           // this, when actual_batch_size is provided, the
+                           // kernel exits if the current batch index is
+                           // greater than or equal to actual_batch_size.
+    bool update_kv) {
   // Launch b_t_(sum(h)) warps.
   auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y;
   auto B_T = XQ.size(0);
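The `actual_batch_size` comment above boils down to an early-exit guard at the top of the kernel. A minimal, self-contained sketch of that pattern follows; the kernel and variable names are illustrative only, not this file's actual control flow:

#include <cstdint>

// Illustrative sketch: under CUDA graph capture/replay, rows of block_tables
// at index >= *actual_batch_size are uninitialized, so threads mapped to
// those batch rows must return before touching them.
__global__ void batch_guard_sketch(
    const int64_t* actual_batch_size, // nullptr outside CUDA graph mode
    const int32_t* block_tables,
    int32_t block_tables_b_stride,
    int32_t* first_page_out) {
  const int32_t b = blockIdx.x; // assume one block per batch row in this sketch
  if (actual_batch_size != nullptr && b >= *actual_batch_size) {
    return; // guard: never read an uninitialized block_tables row
  }
  first_page_out[b] = block_tables[b * block_tables_b_stride];
}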
@@ -624,77 +627,84 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false) {
   }
 }
 
-#define CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( \
-    NUM_GROUPS, \
-    DTYPE, \
-    EMB_MODE, \
-    VARSEQ_BATCH, \
-    VARSEQ_SEQPOS, \
-    THETA, \
-    GAMMA, \
-    SCALE_BASE, \
-    EXPO_OFFSET, \
-    block_tables, \
-    page_size, \
-    block_tables_b_stride, \
-    varseq_cache_seqpos, \
-    actual_batch_size, \
-    rope_scaling, \
-    old_context_len, \
-    scaling_factor, \
-    lo_freq_factor, \
-    hi_freq_factor, \
-    write_k_back, \
-    k_norm) \
-  rope_xpos_qkv_varseq_prefill_kernel_quantized<EMB_MODE, DTYPE, NUM_GROUPS> \
-      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
-          XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
-          XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
-          cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          qparam_k_ptr, \
-          qparam_v_ptr, \
-          XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
-          VARSEQ_BATCH, \
-          VARSEQ_SEQPOS, \
-          THETA, \
-          GAMMA, \
-          SCALE_BASE, \
-          EXPO_OFFSET, \
-          block_tables, \
-          page_size, \
-          block_tables_b_stride, \
-          varseq_cache_seqpos, \
-          actual_batch_size, \
-          rope_scaling, \
-          old_context_len, \
-          scaling_factor, \
-          lo_freq_factor, \
-          hi_freq_factor, \
-          write_k_back, \
-          k_norm);
+#define CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( \
+    NUM_GROUPS, \
+    DTYPE, \
+    EMB_MODE, \
+    VARSEQ_BATCH, \
+    VARSEQ_SEQPOS, \
+    THETA, \
+    GAMMA, \
+    SCALE_BASE, \
+    EXPO_OFFSET, \
+    block_tables, \
+    page_size, \
+    block_tables_b_stride, \
+    varseq_cache_seqpos, \
+    actual_batch_size, \
+    rope_scaling, \
+    old_context_len, \
+    scaling_factor, \
+    lo_freq_factor, \
+    hi_freq_factor, \
+    write_k_back, \
+    k_norm) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (rope_xpos_qkv_varseq_prefill_kernel_quantized< \
+          EMB_MODE, \
+          DTYPE, \
+          NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      PTA_B(XQ, at::BFloat16, 3, 32), \
+      PTA_B(XK, at::BFloat16, 3, 32), \
+      PTA_B(XV, at::BFloat16, 3, 32), \
+      PTA_B(cache_K, uint8_t, 4, 64), \
+      PTA_B(cache_V, uint8_t, 4, 64), \
+      qparam_k_ptr, \
+      qparam_v_ptr, \
+      PTA_B(XQ_O, at::BFloat16, 3, 32), \
+      VARSEQ_BATCH, \
+      VARSEQ_SEQPOS, \
+      THETA, \
+      GAMMA, \
+      SCALE_BASE, \
+      EXPO_OFFSET, \
+      block_tables, \
+      page_size, \
+      block_tables_b_stride, \
+      varseq_cache_seqpos, \
+      actual_batch_size, \
+      rope_scaling, \
+      old_context_len, \
+      scaling_factor, \
+      lo_freq_factor, \
+      hi_freq_factor, \
+      write_k_back, \
+      k_norm);
 
 template <
     PositionEmbeddingMode EmbMode,
     CacheLogicalDtype kCacheDtype,
     int KVQuantNumGroups = 1>
 __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XQ, // [B_T][N_H][D_H]
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XK, // [B_T][N_KVH][D_H]
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XV, // [B_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_K, // [B][MAX_T][N_KVH][D_H] or
                  // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_V, // [B][MAX_T][N_KVH][D_H] or
                  // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
     int32_t* qparam_k_ptr,
     int32_t* qparam_v_ptr,
-    at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
         XQ_O, // [B_T][N_H][D]
     int32_t* varseq_batch, // in decoding case we have T == 1 and so just
                            // pass nullptr
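A reading aid for the `PTA_B(tensor, dtype, rank, nbits)` arguments introduced above: they mirror the `packed_accessor32`/`packed_accessor64` calls they replace, with the trailing number selecting the index width. The stand-in macro below captures only that shape; it is an assumption for illustration, not FBGEMM's actual `PTA_B`, which appears to also wire the launching kernel's name into the `pta::` accessor types for bounds checking:

#include <ATen/ATen.h>

// Hypothetical stand-in: PTA_B_SKETCH(XQ, at::BFloat16, 3, 32) expands to
// XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>() by pasting
// the index width onto the accessor name.
#define PTA_B_SKETCH(TENSOR, T, RANK, NBITS) \
  (TENSOR).packed_accessor##NBITS<T, RANK, at::RestrictPtrTraits>()

// Usage, given an at::Tensor XQ of rank 3:
//   auto xq_acc = PTA_B_SKETCH(XQ, at::BFloat16, 3, 32);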
@@ -1111,27 +1121,27 @@ at::Tensor nope_qkv_varseq_prefill(
   CacheLogicalDtype cache_logical_dtype =
       static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
   if (cache_K.dtype() == at::kBFloat16) {
-    nope_qkv_varseq_prefill_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (nope_qkv_varseq_prefill_kernel),
         blocks,
         threads,
         0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
+        at::cuda::getCurrentCUDAStream(),
+        PTA_B(XQ, at::BFloat16, 3, 32),
+        PTA_B(XK, at::BFloat16, 3, 32),
+        PTA_B(XV, at::BFloat16, 3, 32),
+        PTA_B(cache_K, at::BFloat16, 4, 64),
+        PTA_B(cache_V, at::BFloat16, 4, 64),
+        PTA_B(XQ_O, at::BFloat16, 3, 32),
         varseq_batch.data_ptr<int32_t>(),
-        varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        PTA_B(varseq_seqpos, int32_t, 1, 32),
         block_tables_ptr,
         page_size,
         block_tables_b_stride,
-        varseq_cache_seqpos_
-            .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        PTA_B(varseq_cache_seqpos_, int32_t, 1, 32),
         actual_batch_size_ptr,
         update_kv);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   } else {
     auto num_groups_ = num_groups ? num_groups.value() : 1;
     int32_t* qparam_k_ptr = nullptr;
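Note that the `C10_CUDA_KERNEL_LAUNCH_CHECK()` deleted in this hunk is not lost: the launcher macro now owns the post-launch error check. Below is a reduced sketch of that wrapper pattern, assuming only the behavior visible in this diff; the real `FBGEMM_LAUNCH_KERNEL` in `fbgemm_gpu/utils/kernel_launcher.cuh` may do more, such as validating the launch configuration:

#include <c10/cuda/CUDAException.h>

// Reduced illustration: fold the kernel launch and its error check into one
// macro so call sites cannot forget the check after <<<...>>>.
#define LAUNCH_KERNEL_SKETCH(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
  do {                                                               \
    (KERNEL)<<<(GRID), (BLOCK), (SMEM), (STREAM)>>>(__VA_ARGS__);    \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                  \
  } while (0)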
@@ -1315,27 +1325,27 @@ at::Tensor nope_qkv_decoding(
   CacheLogicalDtype cache_logical_dtype =
       static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
   if (cache_K.dtype() == at::kBFloat16) {
-    nope_qkv_varseq_prefill_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (nope_qkv_varseq_prefill_kernel),
         blocks,
         threads,
         0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
-        cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
+        at::cuda::getCurrentCUDAStream(),
+        PTA_B(XQ, at::BFloat16, 3, 32),
+        PTA_B(XK, at::BFloat16, 3, 32),
+        PTA_B(XV, at::BFloat16, 3, 32),
+        PTA_B(cache_K, at::BFloat16, 4, 64),
+        PTA_B(cache_V, at::BFloat16, 4, 64),
+        PTA_B(XQ_O, at::BFloat16, 3, 32),
         batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
-        seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        PTA_B(seqpos, int32_t, 1, 32),
         block_tables_ptr,
         page_size,
         block_tables_b_stride,
-        cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        PTA_B(cache_seqpos_, int32_t, 1, 32),
         actual_batch_size_ptr,
         update_kv);
 
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     auto num_groups_ = num_groups ? num_groups.value() : 1;
     int32_t* qparam_k_ptr = nullptr;