
Commit 1eacb9e

q10 authored and facebook-github-bot committed
Migrate GenAI gqa attn splitk kernels to FBGEMM_LAUNCH_KERNEL, pt 2 (#4914)
Summary:
Pull Request resolved: #4914

- Migrate GenAI gqa attn splitk kernels to `FBGEMM_LAUNCH_KERNEL`, pt 2

Reviewed By: r-barnes

Differential Revision: D81829067

fbshipit-source-id: 8a09a460db4b61794913ea3cf4fdbe86e950631b
1 parent 42600ba commit 1eacb9e
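
For context, the migration pattern this commit applies is sketched below. SIMPLE_LAUNCH_KERNEL, scale_kernel, and launch_scale are hypothetical names invented for illustration; only the argument order (kernel, grid, block, shared memory, stream, kernel arguments) is taken from this diff, and the real FBGEMM_LAUNCH_KERNEL in fbgemm_gpu does more than this stand-in.

// Illustrative sketch only. SIMPLE_LAUNCH_KERNEL is a hypothetical stand-in
// mimicking the call shape FBGEMM_LAUNCH_KERNEL has in this diff:
// (kernel, grid, block, smem, stream, kernel args...). The real macro is
// defined in fbgemm_gpu and performs additional validation.
#include <cstdio>
#include <cuda_runtime.h>

#define SIMPLE_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...)       \
  do {                                                                     \
    kernel<<<(grid), (block), (smem), (stream)>>>(__VA_ARGS__);            \
    /* Fold the post-launch check into the launch site; this is why   */   \
    /* the diff can drop the separate C10_CUDA_KERNEL_LAUNCH_CHECK(). */   \
    cudaError_t e_ = cudaGetLastError();                                   \
    if (e_ != cudaSuccess) {                                               \
      std::fprintf(stderr, "launch failed: %s\n", cudaGetErrorString(e_)); \
    }                                                                      \
  } while (0)

// Hypothetical kernel, standing in for the gqa/mqa attention kernels.
__global__ void scale_kernel(float* out, const float* in, float s, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = s * in[i];
  }
}

void launch_scale(
    float* out, const float* in, float s, int n, cudaStream_t stream) {
  const dim3 blocks((n + 255) / 256);
  const dim3 threads(256);
  // Before: raw triple-chevron launch plus a separate launch check.
  //   scale_kernel<<<blocks, threads, 0, stream>>>(out, in, s, n);
  //   C10_CUDA_KERNEL_LAUNCH_CHECK();
  // After: one macro carries the launch config, stream, args, and check.
  // The kernel name is parenthesized, matching the diff's style, so that
  // names containing commas (e.g. multi-parameter templates) survive
  // macro expansion.
  SIMPLE_LAUNCH_KERNEL(
      (scale_kernel), blocks, threads, 0, stream, out, in, s, n);
}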

File tree

1 file changed: +38 -31 lines changed


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/gqa_attn_splitk.cu

Lines changed: 38 additions & 31 deletions
@@ -1632,18 +1632,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
         O.packed_accessor32<float, 5, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
   } else {
-#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
-  if (set_max_dynamic_smem) { \
-    set_gpu_max_dynamic_shared_memory( \
-        gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>, smem, device); \
-  } \
-  gqa_attn_splitk_v_int4_kernel<NUM_GROUPS> \
-      <<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>( \
-          attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(), \
-          cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          O.packed_accessor32<float, 5, at::RestrictPtrTraits>(), \
-          seq_positions \
-              .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+#define CALL_MQA_ATTN_SPLITKV_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
+  if (set_max_dynamic_smem) { \
+    set_gpu_max_dynamic_shared_memory( \
+        gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>, smem, device); \
+  } \
+  FBGEMM_LAUNCH_KERNEL( \
+      (gqa_attn_splitk_v_int4_kernel<NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      smem, \
+      at::cuda::getCurrentCUDAStream(), \
+      attn_out.packed_accessor32<float, 3, at::RestrictPtrTraits>(), \
+      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
+      O.packed_accessor32<float, 5, at::RestrictPtrTraits>(), \
+      seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
 
       auto num_groups_ = num_groups ? num_groups.value() : 1;
       CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
@@ -2535,30 +2538,31 @@ at::Tensor mqa_attn(
     if (set_max_dynamic_smem) {
       set_gpu_max_dynamic_shared_memory(mqa_attn_kernel, smem, XQ.get_device());
     }
-    mqa_attn_kernel<<<
+    FBGEMM_LAUNCH_KERNEL(
+        (mqa_attn_kernel),
         blocks,
         threads,
         smem,
-        at::cuda::getCurrentCUDAStream()>>>(
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
         cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
         O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
         seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
         qk_scale);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     if (cache_logical_dtype == CacheLogicalDtype::FP8) {
 #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
       if (set_max_dynamic_smem) {
         set_gpu_max_dynamic_shared_memory(
             mqa_attn_fp8_kernel, smem, XQ.get_device());
       }
-      mqa_attn_fp8_kernel<<<
+      FBGEMM_LAUNCH_KERNEL(
+          (mqa_attn_fp8_kernel),
           blocks,
           threads,
           smem,
-          at::cuda::getCurrentCUDAStream()>>>(
+          at::cuda::getCurrentCUDAStream(),
           XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(),
           cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
           cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
@@ -2569,20 +2573,23 @@ at::Tensor mqa_attn(
       throw std::runtime_error("CUDA version is older than 12.0");
 #endif
     } else {
-#define CALL_MQA_ATTN_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
-  if (set_max_dynamic_smem) { \
-    set_gpu_max_dynamic_shared_memory( \
-        mqa_attn_int4_kernel<NUM_GROUPS>, smem, XQ.get_device()); \
-  } \
-  mqa_attn_int4_kernel<NUM_GROUPS> \
-      <<<blocks, threads, smem, at::cuda::getCurrentCUDAStream()>>>( \
-          XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-          cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
-          O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
-          seq_positions \
-              .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
-          qk_scale);
+#define CALL_MQA_ATTN_INT4_GROUPWISE_KERNEL(NUM_GROUPS, ...) \
+  if (set_max_dynamic_smem) { \
+    set_gpu_max_dynamic_shared_memory( \
+        mqa_attn_int4_kernel<NUM_GROUPS>, smem, XQ.get_device()); \
+  } \
+  FBGEMM_LAUNCH_KERNEL( \
+      (mqa_attn_int4_kernel<NUM_GROUPS>), \
+      blocks, \
+      threads, \
+      smem, \
+      at::cuda::getCurrentCUDAStream(), \
+      XQ.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
+      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
+      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
+      O.packed_accessor32<at::BFloat16, 4, at::RestrictPtrTraits>(), \
+      seq_positions.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
+      qk_scale);
 
       auto num_groups_ = num_groups ? num_groups.value() : 1;
       CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
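
In each hunk the raw `<<<...>>>` launch, and in the bf16 `mqa_attn_kernel` path the standalone `C10_CUDA_KERNEL_LAUNCH_CHECK()`, are replaced by a single `FBGEMM_LAUNCH_KERNEL` call site that takes the grid, block, shared-memory size, and stream up front, followed by the kernel arguments. The dropped check suggests the macro performs post-launch error checking itself, and the parentheses around each kernel name are presumably there so that kernel names containing commas survive macro expansion.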
