Skip to content

Commit 12a1be8

Browse files
q10 authored and facebook-github-bot committed
Migrate GenAI kv cache kernels to FBGEMM_LAUNCH_KERNEL, pt 1 (pytorch#4872)
Summary: X-link: facebookresearch/FBGEMM#1893 Pull Request resolved: pytorch#4872 - Migrate GenAI kv cache kernels to `FBGEMM_LAUNCH_KERNEL`, pt 1 Reviewed By: r-barnes Differential Revision: D81629485 fbshipit-source-id: ae63b5fbfbd44887b4c45870f3f57d771bd4dcb7
1 parent d1340d8 commit 12a1be8

File tree

1 file changed

+106
-96
lines changed

1 file changed

+106
-96
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 106 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include <cub/cub.cuh>
2626

2727
#include "fbgemm_gpu/utils/cuda_block_count.h"
28+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
2829
#include "fbgemm_gpu/utils/vec_quant.cuh"
2930

3031
#include <torch/torch.h>
@@ -125,38 +126,40 @@ DEVICE_INLINE void per_row_amax(fx4& dst, float* amax) {
125126
}
126127
}
127128
__global__ void nope_qkv_varseq_prefill_kernel(
128-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
129+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
129130
XQ, // [B_T][N_H][D_H]
130-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
131+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
131132
XK, // [B_T][N_KVH][D_H]
132-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
133+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
133134
XV, // [B_T][N_KVH][D_H]
134-
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
135+
pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
135136
cache_K, // [B][MAX_T][N_KVH][D_H] or
136137
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
137-
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
138+
pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
138139
cache_V, // [B][MAX_T][N_KVH][D_H] or
139140
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
140-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
141+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
141142
XQ_O, // [B_T][N_H][D]
142143
int32_t* varseq_batch, // in decoding case we have T == 1 and so just pass
143144
// nullptr
144-
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> varseq_seqpos,
145+
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
146+
varseq_seqpos,
145147
int32_t* block_tables, // [B][MAX_PAGES], maps logical pages to physical
146148
// ones for paged attention
147149
int32_t page_size,
148150
int32_t block_tables_b_stride,
149-
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
151+
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
150152
varseq_cache_seqpos,
151-
int64_t* actual_batch_size =
152-
nullptr, // When running in CUDA graph mode, the actual batch size
153-
// can be smaller than block_tables.size(0). In this case
154-
// rows of block_tables beyond actual_batch_size are not
155-
// initialized, and using them wil cause undefined
156-
// behavior. To prevent this, when actual_batch_size is
157-
// provided, the kernel exits if the current batch index is
158-
// larger of equal to actual_batch_size,
159-
bool update_kv = true) {
153+
int64_t*
154+
actual_batch_size, // When running in CUDA graph mode, the actual batch
155+
// size can be smaller than block_tables.size(0). In
156+
// this case rows of block_tables beyond
157+
// actual_batch_size are not initialized, and using
158+
// them wil cause undefined behavior. To prevent
159+
// this, when actual_batch_size is provided, the
160+
// kernel exits if the current batch index is larger
161+
// of equal to actual_batch_size,
162+
bool update_kv) {
160163
// Launch b_t_(sum(h)) warps.
161164
auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y;
162165
auto B_T = XQ.size(0);
@@ -624,77 +627,84 @@ quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm = false) {
624627
}
625628
}
626629

627-
#define CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( \
628-
NUM_GROUPS, \
629-
DTYPE, \
630-
EMB_MODE, \
631-
VARSEQ_BATCH, \
632-
VARSEQ_SEQPOS, \
633-
THETA, \
634-
GAMMA, \
635-
SCALE_BASE, \
636-
EXPO_OFFSET, \
637-
block_tables, \
638-
page_size, \
639-
block_tables_b_stride, \
640-
varseq_cache_seqpos, \
641-
actual_batch_size, \
642-
rope_scaling, \
643-
old_context_len, \
644-
scaling_factor, \
645-
lo_freq_factor, \
646-
hi_freq_factor, \
647-
write_k_back, \
648-
k_norm) \
649-
rope_xpos_qkv_varseq_prefill_kernel_quantized<EMB_MODE, DTYPE, NUM_GROUPS> \
650-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
651-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
652-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
653-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
654-
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
655-
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
656-
qparam_k_ptr, \
657-
qparam_v_ptr, \
658-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
659-
VARSEQ_BATCH, \
660-
VARSEQ_SEQPOS, \
661-
THETA, \
662-
GAMMA, \
663-
SCALE_BASE, \
664-
EXPO_OFFSET, \
665-
block_tables, \
666-
page_size, \
667-
block_tables_b_stride, \
668-
varseq_cache_seqpos, \
669-
actual_batch_size, \
670-
rope_scaling, \
671-
old_context_len, \
672-
scaling_factor, \
673-
lo_freq_factor, \
674-
hi_freq_factor, \
675-
write_k_back, \
676-
k_norm);
630+
#define CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( \
631+
NUM_GROUPS, \
632+
DTYPE, \
633+
EMB_MODE, \
634+
VARSEQ_BATCH, \
635+
VARSEQ_SEQPOS, \
636+
THETA, \
637+
GAMMA, \
638+
SCALE_BASE, \
639+
EXPO_OFFSET, \
640+
block_tables, \
641+
page_size, \
642+
block_tables_b_stride, \
643+
varseq_cache_seqpos, \
644+
actual_batch_size, \
645+
rope_scaling, \
646+
old_context_len, \
647+
scaling_factor, \
648+
lo_freq_factor, \
649+
hi_freq_factor, \
650+
write_k_back, \
651+
k_norm) \
652+
FBGEMM_LAUNCH_KERNEL( \
653+
(rope_xpos_qkv_varseq_prefill_kernel_quantized< \
654+
EMB_MODE, \
655+
DTYPE, \
656+
NUM_GROUPS>), \
657+
blocks, \
658+
threads, \
659+
0, \
660+
at::cuda::getCurrentCUDAStream(), \
661+
PTA_B(XQ, at::BFloat16, 3, 32), \
662+
PTA_B(XK, at::BFloat16, 3, 32), \
663+
PTA_B(XV, at::BFloat16, 3, 32), \
664+
PTA_B(cache_K, uint8_t, 4, 64), \
665+
PTA_B(cache_V, uint8_t, 4, 64), \
666+
qparam_k_ptr, \
667+
qparam_v_ptr, \
668+
PTA_B(XQ_O, at::BFloat16, 3, 32), \
669+
VARSEQ_BATCH, \
670+
VARSEQ_SEQPOS, \
671+
THETA, \
672+
GAMMA, \
673+
SCALE_BASE, \
674+
EXPO_OFFSET, \
675+
block_tables, \
676+
page_size, \
677+
block_tables_b_stride, \
678+
varseq_cache_seqpos, \
679+
actual_batch_size, \
680+
rope_scaling, \
681+
old_context_len, \
682+
scaling_factor, \
683+
lo_freq_factor, \
684+
hi_freq_factor, \
685+
write_k_back, \
686+
k_norm);
677687

678688
template <
679689
PositionEmbeddingMode EmbMode,
680690
CacheLogicalDtype kCacheDtype,
681691
int KVQuantNumGroups = 1>
682692
__global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
683-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
693+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
684694
XQ, // [B_T][N_H][D_H]
685-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
695+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
686696
XK, // [B_T][N_KVH][D_H]
687-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
697+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
688698
XV, // [B_T][N_KVH][D_H]
689-
at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
699+
pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
690700
cache_K, // [B][MAX_T][N_KVH][D_H] or
691701
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
692-
at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
702+
pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
693703
cache_V, // [B][MAX_T][N_KVH][D_H] or
694704
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
695705
int32_t* qparam_k_ptr,
696706
int32_t* qparam_v_ptr,
697-
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
707+
pta::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
698708
XQ_O, // [B_T][N_H][D]
699709
int32_t* varseq_batch, // in decoding case we have T == 1 and so just
700710
// pass nullptr
@@ -1111,27 +1121,27 @@ at::Tensor nope_qkv_varseq_prefill(
11111121
CacheLogicalDtype cache_logical_dtype =
11121122
static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
11131123
if (cache_K.dtype() == at::kBFloat16) {
1114-
nope_qkv_varseq_prefill_kernel<<<
1124+
FBGEMM_LAUNCH_KERNEL(
1125+
(nope_qkv_varseq_prefill_kernel),
11151126
blocks,
11161127
threads,
11171128
0,
1118-
at::cuda::getCurrentCUDAStream()>>>(
1119-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1120-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1121-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1122-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1123-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1124-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1129+
at::cuda::getCurrentCUDAStream(),
1130+
PTA_B(XQ, at::BFloat16, 3, 32),
1131+
PTA_B(XK, at::BFloat16, 3, 32),
1132+
PTA_B(XV, at::BFloat16, 3, 32),
1133+
PTA_B(cache_K, at::BFloat16, 4, 64),
1134+
PTA_B(cache_V, at::BFloat16, 4, 64),
1135+
PTA_B(XQ_O, at::BFloat16, 3, 32),
11251136
varseq_batch.data_ptr<int32_t>(),
1126-
varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1137+
PTA_B(varseq_seqpos, int32_t, 1, 32),
11271138
block_tables_ptr,
11281139
page_size,
11291140
block_tables_b_stride,
1130-
varseq_cache_seqpos_
1131-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1141+
PTA_B(varseq_cache_seqpos_, int32_t, 1, 32),
11321142
actual_batch_size_ptr,
11331143
update_kv);
1134-
C10_CUDA_KERNEL_LAUNCH_CHECK();
1144+
11351145
} else {
11361146
auto num_groups_ = num_groups ? num_groups.value() : 1;
11371147
int32_t* qparam_k_ptr = nullptr;
@@ -1315,27 +1325,27 @@ at::Tensor nope_qkv_decoding(
13151325
CacheLogicalDtype cache_logical_dtype =
13161326
static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
13171327
if (cache_K.dtype() == at::kBFloat16) {
1318-
nope_qkv_varseq_prefill_kernel<<<
1328+
FBGEMM_LAUNCH_KERNEL(
1329+
(nope_qkv_varseq_prefill_kernel),
13191330
blocks,
13201331
threads,
13211332
0,
1322-
at::cuda::getCurrentCUDAStream()>>>(
1323-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1324-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1325-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1326-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1327-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1328-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1333+
at::cuda::getCurrentCUDAStream(),
1334+
PTA_B(XQ, at::BFloat16, 3, 32),
1335+
PTA_B(XK, at::BFloat16, 3, 32),
1336+
PTA_B(XV, at::BFloat16, 3, 32),
1337+
PTA_B(cache_K, at::BFloat16, 4, 64),
1338+
PTA_B(cache_V, at::BFloat16, 4, 64),
1339+
PTA_B(XQ_O, at::BFloat16, 3, 32),
13291340
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
1330-
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1341+
PTA_B(seqpos, int32_t, 1, 32),
13311342
block_tables_ptr,
13321343
page_size,
13331344
block_tables_b_stride,
1334-
cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1345+
PTA_B(cache_seqpos_, int32_t, 1, 32),
13351346
actual_batch_size_ptr,
13361347
update_kv);
13371348

1338-
C10_CUDA_KERNEL_LAUNCH_CHECK();
13391349
} else {
13401350
auto num_groups_ = num_groups ? num_groups.value() : 1;
13411351
int32_t* qparam_k_ptr = nullptr;

0 commit comments

Comments
 (0)