25
25
#include < cub/cub.cuh>
26
26
27
27
#include " fbgemm_gpu/utils/cuda_block_count.h"
28
+ #include " fbgemm_gpu/utils/kernel_launcher.cuh"
28
29
#include " fbgemm_gpu/utils/vec_quant.cuh"
29
30
30
31
#include < torch/torch.h>
@@ -113,8 +114,12 @@ __global__ void dequantize_int4_cache_kernel(
113
114
}
114
115
115
116
#define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL (NUM_GROUPS, ...) \
116
- dequantize_int4_cache_kernel< \
117
- NUM_GROUPS><<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> ( \
117
+ FBGEMM_LAUNCH_KERNEL ( \
118
+ (dequantize_int4_cache_kernel<NUM_GROUPS>), \
119
+ blocks, \
120
+ threads, \
121
+ 0 , \
122
+ at::cuda::getCurrentCUDAStream (), \
118
123
cache_K.packed_accessor64<uint8_t , 4 , at::RestrictPtrTraits>(), \
119
124
cache_V.packed_accessor64<uint8_t , 4 , at::RestrictPtrTraits>(), \
120
125
kv_seqlen.packed_accessor32<int32_t , 1 , at::RestrictPtrTraits>(), \
@@ -539,16 +544,19 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
539
544
dim3 blocks (B, std::max<int32_t >(1 , kMaxBlocks / B));
540
545
dim3 threads (kThreadsPerWarp , kWarpsPerBlock );
541
546
#define CALL_DEQUANTIZE_FP8_CACHE (EXTERNAL_Q_PARAM ) \
542
- const auto deq_fn = dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>; \
543
- deq_fn<<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> ( \
547
+ FBGEMM_LAUNCH_KERNEL ( \
548
+ (dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>), \
549
+ blocks, \
550
+ threads, \
551
+ 0 , \
552
+ at::cuda::getCurrentCUDAStream (), \
544
553
cache_K.packed_accessor64 <uint8_t , 4 , at::RestrictPtrTraits>(), \
545
554
cache_V.packed_accessor64 <uint8_t , 4 , at::RestrictPtrTraits>(), \
546
555
kv_seqlen.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(), \
547
556
cache_K_dq.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(), \
548
557
cache_V_dq.packed_accessor64 <at::BFloat16, 4 , at::RestrictPtrTraits>(), \
549
558
qparam_k_ptr, \
550
- qparam_v_ptr); \
551
- C10_CUDA_KERNEL_LAUNCH_CHECK ()
559
+ qparam_v_ptr);
552
560
if (block_tables_ptr == nullptr ) {
553
561
if (qparam_k_ptr) {
554
562
CALL_DEQUANTIZE_FP8_CACHE (true );
@@ -557,11 +565,12 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
557
565
}
558
566
#undef CALL_DEQUANTIZE_FP8_CACHE
559
567
} else {
560
- dequantize_fp8_cache_kernel_paged<<<
568
+ FBGEMM_LAUNCH_KERNEL (
569
+ (dequantize_fp8_cache_kernel_paged),
561
570
blocks,
562
571
threads,
563
572
0 ,
564
- at::cuda::getCurrentCUDAStream ()>>>(
573
+ at::cuda::getCurrentCUDAStream (),
565
574
cache_K.packed_accessor64 <uint8_t , 4 , at::RestrictPtrTraits>(),
566
575
cache_V.packed_accessor64 <uint8_t , 4 , at::RestrictPtrTraits>(),
567
576
kv_seqlen.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>(),
@@ -572,7 +581,6 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
572
581
block_tables_ptr,
573
582
block_tables_b_stride,
574
583
page_size);
575
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
576
584
}
577
585
578
586
return {cache_K_dq, cache_V_dq};
@@ -752,11 +760,13 @@ at::Tensor quantize_qkv_per_head(
752
760
auto scale_q = at::zeros ({B, N_KVH_L}, XQ_O.options ().dtype (at::kFloat ));
753
761
float * const scale_q_ptr = scale_q.data_ptr <float >();
754
762
// Launch the kernel
755
- quantizeQKVPerHead<<<
763
+
764
+ FBGEMM_LAUNCH_KERNEL (
765
+ (quantizeQKVPerHead),
756
766
grid_size,
757
767
block_size,
758
768
0 ,
759
- at::cuda::getCurrentCUDAStream ()>>>(
769
+ at::cuda::getCurrentCUDAStream (),
760
770
xqkv_amax_row.data_ptr <float >(),
761
771
xqkv.data_ptr <at::BFloat16>(),
762
772
varseq_seqpos.data_ptr <int32_t >(),
@@ -770,8 +780,8 @@ at::Tensor quantize_qkv_per_head(
770
780
cache_V.packed_accessor64 <at::Float8_e4m3fn, 4 , at::RestrictPtrTraits>(),
771
781
scale_q_ptr,
772
782
qparam_k_ptr,
773
- qparam_v_ptr);
774
- C10_CUDA_KERNEL_LAUNCH_CHECK ( );
783
+ qparam_v_ptr,
784
+ 64 . f );
775
785
return scale_q;
776
786
}
777
787
#else
0 commit comments