
Commit 9b66962

q10 authored and facebook-github-bot committed
Migrate GenAI kv cache kernel arguments to use PTA_B (#4899)
Summary:
Pull Request resolved: #4899

- Migrate GenAI kv cache kernel arguments to use `PTA_B`

Reviewed By: cthi

Differential Revision: D82775043

fbshipit-source-id: 83a96d12149fe21e97ba1e040582ca46c8ea53ec
1 parent e6d8742 commit 9b66962
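For context on the change pattern: every kernel launch site in this commit swaps a hand-written packed_accessor call for the PTA_B macro, and the matching kernel parameters move from at:: to pta:: accessor types. A minimal before/after sketch of two launch arguments, taken from the hunks below (the exact expansion of PTA_B lives in the FBGEMM utility headers and is assumed here to produce the pta::PackedTensorAccessor that the updated kernel signatures declare):

// Before: build each accessor by hand at the launch site.
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),

// After: PTA_B(tensor, scalar_type, rank, index_nbits); the trailing 64/32
// mirrors the packed_accessor64 vs. packed_accessor32 calls above.
PTA_B(cache_K, uint8_t, 4, 64),
PTA_B(kv_seqlen, int32_t, 1, 32),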

File tree

1 file changed: +61 -62 lines changed


fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_dequantize.cu

Lines changed: 61 additions & 62 deletions
@@ -13,7 +13,6 @@
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/Atomic.cuh>
-#include <algorithm>
 #include "c10/core/ScalarType.h"
 #include "c10/util/BFloat16.h"
 #include "kv_cache.cuh"
@@ -34,14 +33,14 @@ namespace fbgemm_gpu {
 
 template <int KVQuantNumGroups = 1>
 __global__ void dequantize_int4_cache_kernel(
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_K, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_V, // [B][MAX_T][N_KVH][D_H // G]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_K_dq, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_V_dq // [B][MAX_T][N_KVH][D_H]
 ) {
   auto N_KVH = cache_K.size(2);
@@ -113,18 +112,18 @@ __global__ void dequantize_int4_cache_kernel(
   }
 }
 
-#define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
-  FBGEMM_LAUNCH_KERNEL( \
-      (dequantize_int4_cache_kernel<NUM_GROUPS>), \
-      blocks, \
-      threads, \
-      0, \
-      at::cuda::getCurrentCUDAStream(), \
-      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-      kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
-      cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-      cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>());
+#define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (dequantize_int4_cache_kernel<NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      PTA_B(cache_K, uint8_t, 4, 64), \
+      PTA_B(cache_V, uint8_t, 4, 64), \
+      PTA_B(kv_seqlen, int32_t, 1, 32), \
+      PTA_B(cache_K_dq, at::BFloat16, 4, 64), \
+      PTA_B(cache_V_dq, at::BFloat16, 4, 64));
 
 std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
     at::Tensor cache_K,
@@ -178,14 +177,14 @@ std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
 template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_K, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
         cache_V, // [B][MAX_T][N_KVH][D_H // G]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_K_dq, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
         cache_V_dq, // [B][MAX_T][N_KVH][D_H]
     int32_t* qparam_k_ptr,
     int32_t* qparam_v_ptr) {
@@ -262,14 +261,14 @@ __global__ void dequantize_fp8_cache_kernel(
 
 __global__ void dequantize_fp8_cache_kernel_paged(
     // This code currently represents FP8 version not int4
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
        cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
        cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_K_dq, // [1][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_V_dq, // [1][MAX_T][N_KVH][D_H]
     int32_t* qparam_k_ptr,
     int32_t* qparam_v_ptr,
@@ -283,14 +282,14 @@ __global__ void dequantize_fp8_cache_kernel_paged(
 template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
        cache_K, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
       cache_V, // [B][MAX_T][N_KVH][D_H // G]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_K_dq, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_V_dq, // [B][MAX_T][N_KVH][D_H]
     int32_t* qparam_k_ptr,
     int32_t* qparam_v_ptr) {
@@ -375,14 +374,14 @@ __global__ void dequantize_fp8_cache_kernel(
 // kernel for now.
 __global__ void dequantize_fp8_cache_kernel_paged(
     // This code currently represents FP8 version not int4
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
        cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H]
-    at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
        cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> kv_seqlen,
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_K_dq, // [1][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
        cache_V_dq, // [1][MAX_T][N_KVH][D_H]
     int32_t* qparam_k_ptr,
     int32_t* qparam_v_ptr,
@@ -543,19 +542,19 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   constexpr int32_t kMaxBlocks = 512;
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
-#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM) \
-  FBGEMM_LAUNCH_KERNEL( \
-      (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
-      blocks, \
-      threads, \
-      0, \
-      at::cuda::getCurrentCUDAStream(), \
-      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-      kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
-      cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-      cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-      qparam_k_ptr, \
+#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
+      blocks, \
+      threads, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      PTA_B(cache_K, uint8_t, 4, 64), \
+      PTA_B(cache_V, uint8_t, 4, 64), \
+      PTA_B(kv_seqlen, int32_t, 1, 32), \
+      PTA_B(cache_K_dq, at::BFloat16, 4, 64), \
+      PTA_B(cache_V_dq, at::BFloat16, 4, 64), \
+      qparam_k_ptr, \
       qparam_v_ptr);
   if (block_tables_ptr == nullptr) {
     if (qparam_k_ptr) {
@@ -571,11 +570,11 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
         threads,
         0,
         at::cuda::getCurrentCUDAStream(),
-        cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-        cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+        PTA_B(cache_K, uint8_t, 4, 64),
+        PTA_B(cache_V, uint8_t, 4, 64),
+        PTA_B(kv_seqlen, int32_t, 1, 32),
+        PTA_B(cache_K_dq, at::BFloat16, 4, 64),
+        PTA_B(cache_V_dq, at::BFloat16, 4, 64),
         qparam_k_ptr,
         qparam_v_ptr,
         block_tables_ptr,
@@ -612,11 +611,11 @@ __global__ void quantizeQKVPerHead(
     const int32_t* varseq_seqpos, // [B_T]
    const int32_t* varseq_batch, // [B_T]
     const bool* is_precalculated_qparam, // [B_T]
-    at::PackedTensorAccessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>
        XQ_O, // [B_T][N_H][D]
-    at::PackedTensorAccessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>
        cache_K, // [B][MAX_T][N_KVH][D_H]
-    at::PackedTensorAccessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>
        cache_V, // [B][MAX_T][N_KVH][D_H]
     float* const scale_q,
     float* const scale_k,
@@ -775,9 +774,9 @@ at::Tensor quantize_qkv_per_head(
       is_precalculated_qparam.has_value()
          ? is_precalculated_qparam.value().data_ptr<bool>()
          : nullptr,
-      XQ_O.packed_accessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>(),
-      cache_K.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
-      cache_V.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
+      PTA_B(XQ_O, at::Float8_e4m3fn, 3, 64),
+      PTA_B(cache_K, at::Float8_e4m3fn, 4, 64),
+      PTA_B(cache_V, at::Float8_e4m3fn, 4, 64),
       scale_q_ptr,
       qparam_k_ptr,
      qparam_v_ptr,

0 commit comments
