
Commit 944647a

q10 authored and facebook-github-bot committed
Migrate GenAI kv cache kernels to FBGEMM_LAUNCH_KERNEL, pt 4 (#4895)
Summary:
Pull Request resolved: #4895

- Migrate GenAI kv cache kernels to `FBGEMM_LAUNCH_KERNEL`, pt 4

Reviewed By: cthi

Differential Revision: D82773258

fbshipit-source-id: f2684f75eb09a85aa918305fbde0efaaa0655b64
1 parent 53f9e51 commit 944647a
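
The change follows one pattern throughout: raw triple-chevron kernel launches followed by C10_CUDA_KERNEL_LAUNCH_CHECK() are replaced with a single FBGEMM_LAUNCH_KERNEL invocation from fbgemm_gpu/utils/kernel_launcher.cuh. Below is a minimal before/after sketch of that pattern; my_kernel and its arguments are hypothetical stand-ins, the argument order (kernel, grid, block, shared-memory bytes, stream, kernel arguments) is taken from the diffs below, and any error checking done inside the macro is assumed rather than shown.

    // Before: manual launch plus an explicit post-launch error check.
    my_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        arg0, arg1);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // After: one macro call. The kernel name is parenthesized so that
    // template arguments containing commas survive macro expansion; the
    // explicit launch check is dropped (assumed to be handled by the macro).
    FBGEMM_LAUNCH_KERNEL(
        (my_kernel),
        blocks,                           // grid dimensions
        threads,                          // block dimensions
        0,                                // dynamic shared memory bytes
        at::cuda::getCurrentCUDAStream(), // stream to launch on
        arg0,
        arg1);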

File tree

2 files changed: +35 -25 lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_convert.cu
fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_convert.cu

Lines changed: 12 additions & 12 deletions
@@ -23,6 +23,7 @@
 #endif
 
 #include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/vec_quant.cuh"
 
 #include <torch/torch.h>
@@ -47,12 +48,12 @@ namespace fbgemm_gpu {
  * 32-63 to convert the V tensors. NV only has threads 0-31 per warp.
  */
 __global__ void convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace_kernel(
-    at::PackedTensorAccessor64<uint8_t, 5, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 5, at::RestrictPtrTraits>
         cache_K, // [N_H_L][B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 5, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 5, at::RestrictPtrTraits>
         cache_V, // [N_H_L][B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<int32_t, 5, at::RestrictPtrTraits> qparam_K,
-    at::PackedTensorAccessor64<int32_t, 5, at::RestrictPtrTraits> qparam_V) {
+    pta::PackedTensorAccessor64<int32_t, 5, at::RestrictPtrTraits> qparam_K,
+    pta::PackedTensorAccessor64<int32_t, 5, at::RestrictPtrTraits> qparam_V) {
   auto N_KVH = cache_K.size(3);
   auto MAX_T = cache_K.size(2);
   auto D_H = cache_K.size(4);
@@ -133,17 +134,16 @@ void convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace(
   dim3 blocks(N_H_L, B, std::max<int32_t>(1, kMaxBlocks / (B * N_H_L)));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
 
-  convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace_kernel<<<
+  FBGEMM_LAUNCH_KERNEL(
+      (convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace_kernel),
       blocks,
       threads,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      cache_K.packed_accessor64<uint8_t, 5, at::RestrictPtrTraits>(),
-      cache_V.packed_accessor64<uint8_t, 5, at::RestrictPtrTraits>(),
-      qparam_K.packed_accessor64<int32_t, 5, at::RestrictPtrTraits>(),
-      qparam_V.packed_accessor64<int32_t, 5, at::RestrictPtrTraits>());
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(cache_K, uint8_t, 5, 64),
+      PTA_B(cache_V, uint8_t, 5, 64),
+      PTA_B(qparam_K, int32_t, 5, 64),
+      PTA_B(qparam_V, int32_t, 5, 64));
 }
 #else
 void convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace(
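
In this file the kernel parameters also move from at::PackedTensorAccessor64 to pta::PackedTensorAccessor64, and the call site builds accessors with PTA_B(tensor, dtype, ndim, index_bits) instead of tensor.packed_accessor64<...>(). A hedged sketch of how the two sides pair up, using a hypothetical toy_kernel over a 2-D float tensor in place of the 5-D uint8_t caches above (the exact behavior of the pta:: accessor type and the PTA_B builder is inferred only from the usage visible in this diff):

    // Kernel side: the parameter type is the pta:: packed accessor,
    // mirroring convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace_kernel above.
    __global__ void toy_kernel(
        pta::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> data) {
      // Index exactly as with the at:: accessor, e.g. data[i][j].
    }

    // Host side: PTA_B(tensor, dtype, ndim, index_bits) replaces
    // tensor.packed_accessor64<dtype, ndim, at::RestrictPtrTraits>().
    FBGEMM_LAUNCH_KERNEL(
        (toy_kernel),
        blocks,
        threads,
        0,
        at::cuda::getCurrentCUDAStream(),
        PTA_B(data_tensor, float, 2, 64));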

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu

Lines changed: 23 additions & 13 deletions
@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 
 #include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/vec_quant.cuh"
 
 #include <torch/torch.h>
@@ -113,8 +114,12 @@ __global__ void dequantize_int4_cache_kernel(
 }
 
 #define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
-  dequantize_int4_cache_kernel< \
-      NUM_GROUPS><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
+  FBGEMM_LAUNCH_KERNEL( \
+      (dequantize_int4_cache_kernel<NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
       cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
       cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
       kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
@@ -539,16 +544,19 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
 #define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM) \
-  const auto deq_fn = dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>; \
-  deq_fn<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
+  FBGEMM_LAUNCH_KERNEL( \
+      (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
       cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
       cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
       kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
       cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
       cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
       qparam_k_ptr, \
-      qparam_v_ptr); \
-  C10_CUDA_KERNEL_LAUNCH_CHECK()
+      qparam_v_ptr);
   if (block_tables_ptr == nullptr) {
     if (qparam_k_ptr) {
       CALL_DEQUANTIZE_FP8_CACHE(true);
@@ -557,11 +565,12 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     }
 #undef CALL_DEQUANTIZE_FP8_CACHE
   } else {
-    dequantize_fp8_cache_kernel_paged<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (dequantize_fp8_cache_kernel_paged),
         blocks,
         threads,
         0,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
         cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
         kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
@@ -572,7 +581,6 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
         block_tables_ptr,
         block_tables_b_stride,
         page_size);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 
   return {cache_K_dq, cache_V_dq};
@@ -752,11 +760,13 @@ at::Tensor quantize_qkv_per_head(
   auto scale_q = at::zeros({B, N_KVH_L}, XQ_O.options().dtype(at::kFloat));
   float* const scale_q_ptr = scale_q.data_ptr<float>();
   // Launch the kernel
-  quantizeQKVPerHead<<<
+
+  FBGEMM_LAUNCH_KERNEL(
+      (quantizeQKVPerHead),
       grid_size,
       block_size,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
+      at::cuda::getCurrentCUDAStream(),
       xqkv_amax_row.data_ptr<float>(),
       xqkv.data_ptr<at::BFloat16>(),
       varseq_seqpos.data_ptr<int32_t>(),
@@ -770,8 +780,8 @@ at::Tensor quantize_qkv_per_head(
       cache_V.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
       scale_q_ptr,
       qparam_k_ptr,
-      qparam_v_ptr);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+      qparam_v_ptr,
+      64.f);
   return scale_q;
 }
 #else
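
The dequantize file applies the same migration inside helper macros: CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL and CALL_DEQUANTIZE_FP8_CACHE now expand to FBGEMM_LAUNCH_KERNEL calls (with line-continuation backslashes), and the trailing C10_CUDA_KERNEL_LAUNCH_CHECK() calls disappear. A simplified sketch of that macro-wrapped, template-parameter dispatch, with a hypothetical dispatch_kernel<bool> and CALL_DISPATCH standing in for dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM> and CALL_DEQUANTIZE_FP8_CACHE:

    // Both boolean instantiations are compiled; the runtime branch picks
    // which one to launch, as dequantize_fp8_cache does above.
    #define CALL_DISPATCH(EXTERNAL_Q_PARAM)      \
      FBGEMM_LAUNCH_KERNEL(                      \
          (dispatch_kernel<EXTERNAL_Q_PARAM>),   \
          blocks,                                \
          threads,                               \
          0,                                     \
          at::cuda::getCurrentCUDAStream(),      \
          out_ptr)

    if (qparam_k_ptr != nullptr) {
      CALL_DISPATCH(true);   // external quantization parameters provided
    } else {
      CALL_DISPATCH(false);  // no external quantization parameters
    }
    #undef CALL_DISPATCH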
