
Commit 5e8c476

q10 authored and facebook-github-bot committed
Migrate GenAI kv cache kernels to FBGEMM_LAUNCH_KERNEL, pt 2 (pytorch#4880)
Summary:
Pull Request resolved: pytorch#4880

- Migrate GenAI kv cache kernels to `FBGEMM_LAUNCH_KERNEL`, pt 2

Reviewed By: cthi

Differential Revision: D81699243

fbshipit-source-id: 85a11e0077c108000bc3a86698864b1e52692e06
1 parent 9dadc55 commit 5e8c476
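
The change is mechanical: each raw triple-chevron launch of rope_xpos_qkv_varseq_prefill_kernel_fp8 becomes an FBGEMM_LAUNCH_KERNEL call that takes the kernel, the grid and block dimensions, the shared-memory size, and the stream up front, followed by the kernel's own arguments; each trailing C10_CUDA_KERNEL_LAUNCH_CHECK() is dropped, which implies the macro performs its own post-launch error check. Below is a minimal sketch of such a checked-launch macro; CHECKED_LAUNCH_KERNEL, its body, and the toy kernel are illustrative assumptions, not FBGEMM's actual implementation.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Illustrative stand-in for a checked-launch macro (an assumption, not
// FBGEMM's real FBGEMM_LAUNCH_KERNEL): perform the launch, then surface
// any launch error with file/line context, so call sites no longer need a
// separate C10_CUDA_KERNEL_LAUNCH_CHECK()-style macro afterwards.
#define CHECKED_LAUNCH_KERNEL(KERNEL, BLOCKS, THREADS, SMEM, STREAM, ...) \
  do {                                                                    \
    KERNEL<<<(BLOCKS), (THREADS), (SMEM), (STREAM)>>>(__VA_ARGS__);       \
    cudaError_t err_ = cudaGetLastError();                                \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(                                                       \
          stderr,                                                         \
          "CUDA launch failed at %s:%d: %s\n",                            \
          __FILE__,                                                       \
          __LINE__,                                                       \
          cudaGetErrorString(err_));                                      \
      std::abort();                                                       \
    }                                                                     \
  } while (0)

// Toy kernel standing in for the kv cache kernels in the diff below.
__global__ void scale_kernel(float* x, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= alpha;
  }
}

void scale(float* d_x, float alpha, int n, cudaStream_t stream) {
  dim3 threads(256);
  dim3 blocks((n + 255) / 256);
  // Same call shape as the migrated sites: kernel and launch
  // configuration first, then the kernel's own arguments.
  CHECKED_LAUNCH_KERNEL(
      scale_kernel, blocks, threads, 0, stream, d_x, alpha, n);
}

Centralizing the check inside the launch macro removes the failure mode of launching a kernel and forgetting the error check, which is presumably the practical payoff of this migration.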

File tree

1 file changed (+37 −21 lines)


fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 37 additions & 21 deletions
@@ -1167,10 +1167,15 @@ at::Tensor nope_qkv_varseq_prefill(
       is_precalculated_qparam =
           static_cast<bool*>(kv_quant_scale_precomputed.value().data_ptr());
     }
-    rope_xpos_qkv_varseq_prefill_kernel_fp8<
-        PositionEmbeddingMode::NOPE,
-        CacheLogicalDtype::FP8,
-        1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (rope_xpos_qkv_varseq_prefill_kernel_fp8<
+            PositionEmbeddingMode::NOPE,
+            CacheLogicalDtype::FP8,
+            1>),
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
@@ -1205,7 +1210,7 @@ at::Tensor nope_qkv_varseq_prefill(
         k_norm,
         amax_ptr,
         is_precalculated_qparam);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   } else {
     CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
         1,
@@ -1367,10 +1372,15 @@ at::Tensor nope_qkv_decoding(
     if (amax_qkv.has_value()) {
       amax_ptr = static_cast<float*>(amax_qkv.value().data_ptr());
     }
-    rope_xpos_qkv_varseq_prefill_kernel_fp8<
-        PositionEmbeddingMode::NOPE,
-        CacheLogicalDtype::FP8,
-        1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (rope_xpos_qkv_varseq_prefill_kernel_fp8<
+            PositionEmbeddingMode::NOPE,
+            CacheLogicalDtype::FP8,
+            1>),
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
@@ -1406,8 +1416,6 @@ at::Tensor nope_qkv_decoding(
         amax_ptr,
         nullptr);

-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-
   } else {
     CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
         1,
@@ -1596,10 +1604,15 @@ at::Tensor rope_qkv_varseq_prefill(
       is_precalculated_qparam =
           static_cast<bool*>(kv_quant_scale_precomputed.value().data_ptr());
     }
-    rope_xpos_qkv_varseq_prefill_kernel_fp8<
-        PositionEmbeddingMode::ROPE,
-        CacheLogicalDtype::FP8,
-        1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (rope_xpos_qkv_varseq_prefill_kernel_fp8<
+            PositionEmbeddingMode::ROPE,
+            CacheLogicalDtype::FP8,
+            1>),
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
@@ -1634,7 +1647,6 @@ at::Tensor rope_qkv_varseq_prefill(
         k_norm,
         amax_ptr,
         is_precalculated_qparam);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();

   } else {
     CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
@@ -1970,10 +1982,15 @@ at::Tensor rope_qkv_decoding(
     if (amax_qkv.has_value()) {
      amax_ptr = static_cast<float*>(amax_qkv.value().data_ptr());
     }
-    rope_xpos_qkv_varseq_prefill_kernel_fp8<
-        PositionEmbeddingMode::ROPE,
-        CacheLogicalDtype::FP8,
-        1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (rope_xpos_qkv_varseq_prefill_kernel_fp8<
+            PositionEmbeddingMode::ROPE,
+            CacheLogicalDtype::FP8,
+            1>),
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
         XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
         XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
@@ -2009,7 +2026,6 @@ at::Tensor rope_qkv_decoding(
         amax_ptr,
         nullptr);

-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
         1,
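
One easy-to-miss detail in the hunks above: the templated kernel name passed to FBGEMM_LAUNCH_KERNEL is wrapped in parentheses, as in (rope_xpos_qkv_varseq_prefill_kernel_fp8<..., 1>). A function-like macro splits its arguments on top-level commas, so an unparenthesized kernel<A, B, 1> would be parsed as several broken macro arguments; the extra parentheses keep the whole instantiation a single argument. A minimal sketch, with hypothetical names (LAUNCH, fill_kernel), of a macro written to tolerate those parentheses:

#include <cuda_runtime.h>

// Hypothetical templated kernel, standing in for the FP8 kernels above.
template <typename T, int kUnused>
__global__ void fill_kernel(T* out, T value, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = value;
  }
}

// LAUNCH(fill_kernel<float, 1>, ...) would split the template argument
// list into "fill_kernel<float" and "1>", which does not compile. Binding
// the (parenthesized) kernel to a variable first strips the parentheses,
// and CUDA allows launching through the resulting function pointer.
#define LAUNCH(KERNEL, BLOCKS, THREADS, ...)          \
  do {                                                \
    auto kernel_fn_ = KERNEL;                         \
    kernel_fn_<<<(BLOCKS), (THREADS)>>>(__VA_ARGS__); \
  } while (0)

void fill(float* d_out, int n) {
  // The parentheses protect the commas inside the template argument list.
  LAUNCH((fill_kernel<float, 1>), (n + 255) / 256, 256, d_out, 1.0f, n);
}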
