13
13
#include < ATen/cuda/Exceptions.h>
14
14
#include < c10/cuda/CUDAGuard.h>
15
15
#include < ATen/cuda/Atomic.cuh>
16
- #include < algorithm>
17
16
#include " c10/core/ScalarType.h"
18
17
#include " c10/util/BFloat16.h"
19
18
#include " kv_cache.cuh"
@@ -34,14 +33,14 @@ namespace fbgemm_gpu {
34
33
35
34
template <int KVQuantNumGroups = 1 >
36
35
__global__ void dequantize_int4_cache_kernel (
37
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
36
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
38
37
cache_K, // [B][MAX_T][N_KVH][D_H]
39
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
38
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
40
39
cache_V, // [B][MAX_T][N_KVH][D_H // G]
41
- at ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
42
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
40
+ pta ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
41
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
43
42
cache_K_dq, // [B][MAX_T][N_KVH][D_H]
44
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
43
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
45
44
cache_V_dq // [B][MAX_T][N_KVH][D_H]
46
45
) {
47
46
auto N_KVH = cache_K.size (2 );
@@ -113,18 +112,18 @@ __global__ void dequantize_int4_cache_kernel(
113
112
}
114
113
}
115
114
116
- #define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL (NUM_GROUPS, ...) \
117
- FBGEMM_LAUNCH_KERNEL ( \
118
- (dequantize_int4_cache_kernel<NUM_GROUPS>), \
119
- blocks, \
120
- threads, \
121
- 0 , \
122
- at::cuda::getCurrentCUDAStream (), \
123
- cache_K.packed_accessor64< uint8_t , 4 , at::RestrictPtrTraits>(), \
124
- cache_V.packed_accessor64< uint8_t , 4 , at::RestrictPtrTraits>(), \
125
- kv_seqlen.packed_accessor32< int32_t , 1 , at::RestrictPtrTraits>(), \
126
- cache_K_dq.packed_accessor64< at::BFloat16, 4 , at::RestrictPtrTraits>(), \
127
- cache_V_dq.packed_accessor64< at::BFloat16, 4 , at::RestrictPtrTraits>( ));
115
+ #define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL (NUM_GROUPS, ...) \
116
+ FBGEMM_LAUNCH_KERNEL ( \
117
+ (dequantize_int4_cache_kernel<NUM_GROUPS>), \
118
+ blocks, \
119
+ threads, \
120
+ 0 , \
121
+ at::cuda::getCurrentCUDAStream (), \
122
+ PTA_B ( cache_K, uint8_t , 4 , 64 ), \
123
+ PTA_B ( cache_V, uint8_t , 4 , 64 ), \
124
+ PTA_B ( kv_seqlen, int32_t , 1 , 32 ), \
125
+ PTA_B ( cache_K_dq, at::BFloat16, 4 , 64 ), \
126
+ PTA_B ( cache_V_dq, at::BFloat16, 4 , 64 ));
128
127
129
128
std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache (
130
129
at::Tensor cache_K,
@@ -178,14 +177,14 @@ std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
178
177
template <bool ExternalQParam>
179
178
__global__ void dequantize_fp8_cache_kernel (
180
179
// This code currently represents FP8 version not int4
181
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
180
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
182
181
cache_K, // [B][MAX_T][N_KVH][D_H]
183
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
182
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
184
183
cache_V, // [B][MAX_T][N_KVH][D_H // G]
185
- at ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
186
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
184
+ pta ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
185
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
187
186
cache_K_dq, // [B][MAX_T][N_KVH][D_H]
188
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
187
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
189
188
cache_V_dq, // [B][MAX_T][N_KVH][D_H]
190
189
int32_t * qparam_k_ptr,
191
190
int32_t * qparam_v_ptr) {
@@ -262,14 +261,14 @@ __global__ void dequantize_fp8_cache_kernel(
262
261
263
262
__global__ void dequantize_fp8_cache_kernel_paged (
264
263
// This code currently represents FP8 version not int4
265
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
264
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
266
265
cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H]
267
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
266
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
268
267
cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G]
269
- at ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
270
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
268
+ pta ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
269
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
271
270
cache_K_dq, // [1][MAX_T][N_KVH][D_H]
272
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
271
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
273
272
cache_V_dq, // [1][MAX_T][N_KVH][D_H]
274
273
int32_t * qparam_k_ptr,
275
274
int32_t * qparam_v_ptr,
@@ -283,14 +282,14 @@ __global__ void dequantize_fp8_cache_kernel_paged(
283
282
template <bool ExternalQParam>
284
283
__global__ void dequantize_fp8_cache_kernel (
285
284
// This code currently represents FP8 version not int4
286
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
285
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
287
286
cache_K, // [B][MAX_T][N_KVH][D_H]
288
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
287
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
289
288
cache_V, // [B][MAX_T][N_KVH][D_H // G]
290
- at ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
291
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
289
+ pta ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
290
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
292
291
cache_K_dq, // [B][MAX_T][N_KVH][D_H]
293
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
292
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
294
293
cache_V_dq, // [B][MAX_T][N_KVH][D_H]
295
294
int32_t * qparam_k_ptr,
296
295
int32_t * qparam_v_ptr) {
@@ -375,14 +374,14 @@ __global__ void dequantize_fp8_cache_kernel(
375
374
// kernel for now.
376
375
__global__ void dequantize_fp8_cache_kernel_paged (
377
376
// This code currently represents FP8 version not int4
378
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
377
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
379
378
cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H]
380
- at ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
379
+ pta ::PackedTensorAccessor64<uint8_t , 4 , at::RestrictPtrTraits>
381
380
cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G]
382
- at ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
383
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
381
+ pta ::PackedTensorAccessor32<int32_t , 1 , at::RestrictPtrTraits> kv_seqlen,
382
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
384
383
cache_K_dq, // [1][MAX_T][N_KVH][D_H]
385
- at ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
384
+ pta ::PackedTensorAccessor64<at::BFloat16, 4 , at::RestrictPtrTraits>
386
385
cache_V_dq, // [1][MAX_T][N_KVH][D_H]
387
386
int32_t * qparam_k_ptr,
388
387
int32_t * qparam_v_ptr,
@@ -543,19 +542,19 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
543
542
constexpr int32_t kMaxBlocks = 512 ;
544
543
dim3 blocks (B, std::max<int32_t >(1 , kMaxBlocks / B));
545
544
dim3 threads (kThreadsPerWarp , kWarpsPerBlock );
546
- #define CALL_DEQUANTIZE_FP8_CACHE (EXTERNAL_Q_PARAM ) \
547
- FBGEMM_LAUNCH_KERNEL ( \
548
- (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
549
- blocks, \
550
- threads, \
551
- 0 , \
552
- at::cuda::getCurrentCUDAStream (), \
553
- cache_K. packed_accessor64 < uint8_t , 4 , at::RestrictPtrTraits>(), \
554
- cache_V. packed_accessor64 < uint8_t , 4 , at::RestrictPtrTraits>(), \
555
- kv_seqlen. packed_accessor32 < int32_t , 1 , at::RestrictPtrTraits>(), \
556
- cache_K_dq. packed_accessor64 < at::BFloat16, 4 , at::RestrictPtrTraits>(), \
557
- cache_V_dq. packed_accessor64 < at::BFloat16, 4 , at::RestrictPtrTraits>(), \
558
- qparam_k_ptr, \
545
+ #define CALL_DEQUANTIZE_FP8_CACHE (EXTERNAL_Q_PARAM ) \
546
+ FBGEMM_LAUNCH_KERNEL ( \
547
+ (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
548
+ blocks, \
549
+ threads, \
550
+ 0 , \
551
+ at::cuda::getCurrentCUDAStream (), \
552
+ PTA_B ( cache_K, uint8_t , 4 , 64 ), \
553
+ PTA_B ( cache_V, uint8_t , 4 , 64 ), \
554
+ PTA_B ( kv_seqlen, int32_t , 1 , 32 ), \
555
+ PTA_B ( cache_K_dq, at::BFloat16, 4 , 64 ), \
556
+ PTA_B ( cache_V_dq, at::BFloat16, 4 , 64 ), \
557
+ qparam_k_ptr, \
559
558
qparam_v_ptr);
560
559
if (block_tables_ptr == nullptr ) {
561
560
if (qparam_k_ptr) {
@@ -571,11 +570,11 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
571
570
threads,
572
571
0 ,
573
572
at::cuda::getCurrentCUDAStream (),
574
- cache_K. packed_accessor64 < uint8_t , 4 , at::RestrictPtrTraits>( ),
575
- cache_V. packed_accessor64 < uint8_t , 4 , at::RestrictPtrTraits>( ),
576
- kv_seqlen. packed_accessor32 < int32_t , 1 , at::RestrictPtrTraits>( ),
577
- cache_K_dq. packed_accessor64 < at::BFloat16, 4 , at::RestrictPtrTraits>( ),
578
- cache_V_dq. packed_accessor64 < at::BFloat16, 4 , at::RestrictPtrTraits>( ),
573
+ PTA_B ( cache_K, uint8_t , 4 , 64 ),
574
+ PTA_B ( cache_V, uint8_t , 4 , 64 ),
575
+ PTA_B ( kv_seqlen, int32_t , 1 , 32 ),
576
+ PTA_B ( cache_K_dq, at::BFloat16, 4 , 64 ),
577
+ PTA_B ( cache_V_dq, at::BFloat16, 4 , 64 ),
579
578
qparam_k_ptr,
580
579
qparam_v_ptr,
581
580
block_tables_ptr,
@@ -612,11 +611,11 @@ __global__ void quantizeQKVPerHead(
612
611
const int32_t * varseq_seqpos, // [B_T]
613
612
const int32_t * varseq_batch, // [B_T]
614
613
const bool * is_precalculated_qparam, // [B_T]
615
- at ::PackedTensorAccessor64<at::Float8_e4m3fn, 3 , at::RestrictPtrTraits>
614
+ pta ::PackedTensorAccessor64<at::Float8_e4m3fn, 3 , at::RestrictPtrTraits>
616
615
XQ_O, // [B_T][N_H][D]
617
- at ::PackedTensorAccessor64<at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>
616
+ pta ::PackedTensorAccessor64<at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>
618
617
cache_K, // [B][MAX_T][N_KVH][D_H]
619
- at ::PackedTensorAccessor64<at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>
618
+ pta ::PackedTensorAccessor64<at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>
620
619
cache_V, // [B][MAX_T][N_KVH][D_H]
621
620
float * const scale_q,
622
621
float * const scale_k,
@@ -775,9 +774,9 @@ at::Tensor quantize_qkv_per_head(
775
774
is_precalculated_qparam.has_value ()
776
775
? is_precalculated_qparam.value ().data_ptr <bool >()
777
776
: nullptr ,
778
- XQ_O. packed_accessor64 < at::Float8_e4m3fn, 3 , at::RestrictPtrTraits>( ),
779
- cache_K. packed_accessor64 < at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>( ),
780
- cache_V. packed_accessor64 < at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>( ),
777
+ PTA_B ( XQ_O, at::Float8_e4m3fn, 3 , 64 ),
778
+ PTA_B ( cache_K, at::Float8_e4m3fn, 4 , 64 ),
779
+ PTA_B ( cache_V, at::Float8_e4m3fn, 4 , 64 ),
781
780
scale_q_ptr,
782
781
qparam_k_ptr,
783
782
qparam_v_ptr,
0 commit comments