@@ -1632,18 +1632,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
         O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
   } else {
-#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...)      \
-  if (set_max_dynamic_smem) {                                             \
-    set_gpu_max_dynamic_shared_memory(                                    \
-        gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>, smem, device);         \
-  }                                                                       \
-  gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>                               \
-      <<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>(      \
-          attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),  \
-          cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),         \
-          seq_positions                                                   \
-              .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...)        \
+  if (set_max_dynamic_smem) {                                               \
+    set_gpu_max_dynamic_shared_memory(                                      \
+        gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>, smem, device);           \
+  }                                                                         \
+  FBGEMM_LAUNCH_KERNEL(                                                     \
+      (gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>),                          \
+      blocks,                                                               \
+      threads,                                                              \
+      smem,                                                                 \
+      at::cuda::getCurrentCUDAStream(),                                     \
+      attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(),        \
+      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),       \
+      O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),               \
+      seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
 
   auto num_groups_ = num_groups ? num_groups.value() : 1;
   CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
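This hunk (and the two that follow) migrates raw triple-chevron launches to FBGEMM's FBGEMM_LAUNCH_KERNEL wrapper. Note that the kernel is passed inside an extra set of parentheses, e.g. (gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>): this keeps any commas inside a template argument list from being parsed as macro-argument separators. As a minimal sketch of the shape such a wrapper can take, assuming only the call signature visible in this diff (LAUNCH_KERNEL_SKETCH is a hypothetical stand-in, not FBGEMM's actual definition, which likely also records source location and validates launch parameters):

    // Minimal sketch, assuming the call shape wrapper(kernel, grid, block,
    // smem, stream, kernel args...). Not FBGEMM's real macro.
    #include <c10/cuda/CUDAException.h>

    #define LAUNCH_KERNEL_SKETCH(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)    \
      do {                                                                  \
        /* KERNEL may arrive parenthesized, e.g. (foo<A, B>); the extra */  \
        /* parens keep template-argument commas out of macro parsing.   */  \
        KERNEL<<<(GRID), (BLOCK), (SMEM), (STREAM)>>>(__VA_ARGS__);         \
        /* Surface asynchronous launch errors right after the launch.   */  \
        C10_CUDA_KERNEL_LAUNCH_CHECK();                                     \
      } while (0)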
@@ -2535,30 +2538,31 @@ at::Tensor mqa_attn(
     if (set_max_dynamic_smem) {
       set_gpu_max_dynamic_shared_memory(mqa_attn_kernel, smem, XQ.get_device());
     }
-    mqa_attn_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (mqa_attn_kernel),
         blocks,
         threads,
         smem,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
         cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
         qk_scale);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     if (cache_logical_dtype == CacheLogicalDtype::FP8) {
 #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
       if (set_max_dynamic_smem) {
         set_gpu_max_dynamic_shared_memory(
             mqa_attn_fp8_kernel, smem, XQ.get_device());
       }
-      mqa_attn_fp8_kernel<<<
+      FBGEMM_LAUNCH_KERNEL(
+          (mqa_attn_fp8_kernel),
           blocks,
           threads,
           smem,
-          at::cuda::getCurrentCUDAStream()>>>(
+          at::cuda::getCurrentCUDAStream(),
           XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
           cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
           cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
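One behavioral note on the hunk above: the explicit C10_CUDA_KERNEL_LAUNCH_CHECK() after the BF16 launch is deleted with no replacement at the call site, so the working assumption is that FBGEMM_LAUNCH_KERNEL performs the post-launch check internally (as in the sketch above). A self-contained toy comparison of the before/after pattern; toy_kernel and all launch parameters are invented for illustration, and the include path for FBGEMM_LAUNCH_KERNEL is an assumption, not taken from this diff:

    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAException.h>
    // #include "fbgemm_gpu/utils/kernel_launcher.cuh" // assumed home of FBGEMM_LAUNCH_KERNEL

    __global__ void toy_kernel(float* out, float value) {
      out[threadIdx.x] = value;
    }

    void launch_before(float* out) {
      // Old pattern: raw launch plus a separate post-launch error check.
      toy_kernel<<<1, 32, 0, at::cuda::getCurrentCUDAStream()>>>(out, 1.0f);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    void launch_after(float* out) {
      // New pattern: one wrapper call; the error check is presumed internal,
      // which is why the diff drops the explicit check without losing coverage.
      FBGEMM_LAUNCH_KERNEL(
          (toy_kernel), 1, 32, 0, at::cuda::getCurrentCUDAStream(), out, 1.0f);
    }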
@@ -2569,20 +2573,23 @@ at::Tensor mqa_attn(
       throw std::runtime_error("CUDA version is older than 12.0");
 #endif
     } else {
-#define CALL_MQA_ATTN_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...)               \
-  if (set_max_dynamic_smem) {                                              \
-    set_gpu_max_dynamic_shared_memory(                                     \
-        mqa_attn_int4_kernel<NUM_GROUPS>, smem, XQ.get_device());          \
-  }                                                                        \
-  mqa_attn_int4_kernel<NUM_GROUPS>                                         \
-      <<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>(       \
-          XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),  \
-          cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),  \
-          cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),  \
-          O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),   \
-          seq_positions                                                    \
-              .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),     \
-          qk_scale);
+#define CALL_MQA_ATTN_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...)                \
+  if (set_max_dynamic_smem) {                                               \
+    set_gpu_max_dynamic_shared_memory(                                      \
+        mqa_attn_int4_kernel<NUM_GROUPS>, smem, XQ.get_device());           \
+  }                                                                         \
+  FBGEMM_LAUNCH_KERNEL(                                                     \
+      (mqa_attn_int4_kernel<NUM_GROUPS>),                                   \
+      blocks,                                                               \
+      threads,                                                              \
+      smem,                                                                 \
+      at::cuda::getCurrentCUDAStream(),                                     \
+      XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),       \
+      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),       \
+      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),       \
+      O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),        \
+      seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
+      qk_scale);
 
   auto num_groups_ = num_groups ? num_groups.value() : 1;
   CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(