Skip to content

Commit ff8b5b0

Browse files
q10 authored and facebook-github-bot committed
Migrate GenAI kv cache kernels to FBGEMM_LAUNCH_KERNEL, pt 3 (pytorch#4885)
Summary: Pull Request resolved: pytorch#4885 - Migrate GenAI kv cache kernels to `FBGEMM_LAUNCH_KERNEL`, pt 3 Reviewed By: r-barnes Differential Revision: D81703163 fbshipit-source-id: 9be7255cca53ab9ad116422e64bc6abb7964e871
1 parent 17c6316 commit ff8b5b0

File tree

2 files changed

+127
-111
lines changed

2 files changed

+127
-111
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 126 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,35 +1550,38 @@ at::Tensor rope_qkv_varseq_prefill(
15501550
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
15511551
}
15521552
if (cache_K.dtype() == at::kBFloat16) {
1553-
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
1554-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
1555-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1556-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1557-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1558-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1559-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1560-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1561-
varseq_batch.data_ptr<int32_t>(),
1562-
varseq_seqpos
1563-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1564-
theta,
1565-
0,
1566-
0,
1567-
0,
1568-
block_tables_ptr,
1569-
page_size,
1570-
block_tables_b_stride,
1571-
varseq_cache_seqpos_
1572-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1573-
actual_batch_size_ptr,
1574-
rope_scaling,
1575-
old_context_len,
1576-
scaling_factor,
1577-
lo_freq_factor,
1578-
hi_freq_factor,
1579-
write_k_back,
1580-
update_kv);
1581-
C10_CUDA_KERNEL_LAUNCH_CHECK();
1553+
FBGEMM_LAUNCH_KERNEL(
1554+
(rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>),
1555+
blocks,
1556+
threads,
1557+
0,
1558+
at::cuda::getCurrentCUDAStream(),
1559+
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1560+
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1561+
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1562+
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1563+
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1564+
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1565+
varseq_batch.data_ptr<int32_t>(),
1566+
varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1567+
theta,
1568+
0,
1569+
0,
1570+
0,
1571+
block_tables_ptr,
1572+
page_size,
1573+
block_tables_b_stride,
1574+
varseq_cache_seqpos_
1575+
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1576+
actual_batch_size_ptr,
1577+
rope_scaling,
1578+
old_context_len,
1579+
scaling_factor,
1580+
lo_freq_factor,
1581+
hi_freq_factor,
1582+
write_k_back,
1583+
update_kv);
1584+
15821585
} else {
15831586
auto num_groups_ = num_groups ? num_groups.value() : 1;
15841587
auto varseq_batch_ = varseq_batch.data_ptr<int32_t>();
@@ -1767,33 +1770,38 @@ at::Tensor xpos_qkv_varseq_prefill(
17671770
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
17681771
}
17691772
if (cache_K.dtype() == at::kBFloat16) {
1770-
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
1771-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
1772-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1773-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1774-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1775-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1776-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1777-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1778-
varseq_batch.data_ptr<int32_t>(),
1779-
varseq_seqpos
1780-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1781-
theta,
1782-
gamma,
1783-
scale_base,
1784-
exponent_offset,
1785-
block_tables_ptr,
1786-
page_size,
1787-
block_tables_b_stride,
1788-
varseq_cache_seqpos_
1789-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1790-
actual_batch_size_ptr,
1791-
rope_scaling,
1792-
old_context_len,
1793-
scaling_factor,
1794-
lo_freq_factor,
1795-
hi_freq_factor);
1796-
C10_CUDA_KERNEL_LAUNCH_CHECK();
1773+
FBGEMM_LAUNCH_KERNEL(
1774+
(rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>),
1775+
blocks,
1776+
threads,
1777+
0,
1778+
at::cuda::getCurrentCUDAStream(),
1779+
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1780+
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1781+
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1782+
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1783+
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1784+
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1785+
varseq_batch.data_ptr<int32_t>(),
1786+
varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1787+
theta,
1788+
gamma,
1789+
scale_base,
1790+
exponent_offset,
1791+
block_tables_ptr,
1792+
page_size,
1793+
block_tables_b_stride,
1794+
varseq_cache_seqpos_
1795+
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1796+
actual_batch_size_ptr,
1797+
rope_scaling,
1798+
old_context_len,
1799+
scaling_factor,
1800+
lo_freq_factor,
1801+
hi_freq_factor,
1802+
false,
1803+
true);
1804+
17971805
} else {
17981806
auto num_groups_ = num_groups ? num_groups.value() : 1;
17991807
auto varseq_batch_ = varseq_batch.data_ptr<int32_t>();
@@ -1935,34 +1943,37 @@ at::Tensor rope_qkv_decoding(
19351943
}
19361944
auto cache_seqpos_ = cache_seqpos.value_or(seqpos);
19371945
if (cache_K.dtype() == at::kBFloat16) {
1938-
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
1939-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
1940-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1941-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1942-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1943-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1944-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1945-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1946-
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
1947-
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1948-
theta,
1949-
0,
1950-
0,
1951-
0,
1952-
block_tables_ptr,
1953-
page_size,
1954-
block_tables_b_stride,
1955-
cache_seqpos_
1956-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1957-
actual_batch_size_ptr,
1958-
rope_scaling,
1959-
old_context_len,
1960-
scaling_factor,
1961-
lo_freq_factor,
1962-
hi_freq_factor,
1963-
false,
1964-
update_kv);
1965-
C10_CUDA_KERNEL_LAUNCH_CHECK();
1946+
FBGEMM_LAUNCH_KERNEL(
1947+
(rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>),
1948+
blocks,
1949+
threads,
1950+
0,
1951+
at::cuda::getCurrentCUDAStream(),
1952+
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1953+
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1954+
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1955+
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1956+
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
1957+
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
1958+
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
1959+
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1960+
theta,
1961+
0,
1962+
0,
1963+
0,
1964+
block_tables_ptr,
1965+
page_size,
1966+
block_tables_b_stride,
1967+
cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1968+
actual_batch_size_ptr,
1969+
rope_scaling,
1970+
old_context_len,
1971+
scaling_factor,
1972+
lo_freq_factor,
1973+
hi_freq_factor,
1974+
false,
1975+
update_kv);
1976+
19661977
} else {
19671978
auto seqpos_ =
19681979
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
@@ -2142,32 +2153,37 @@ at::Tensor xpos_qkv_decoding(
21422153
}
21432154
auto cache_seqpos_ = cache_seqpos.value_or(seqpos);
21442155
if (cache_K.dtype() == at::kBFloat16) {
2145-
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
2146-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
2147-
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2148-
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2149-
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2150-
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
2151-
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
2152-
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2153-
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
2154-
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
2155-
theta,
2156-
gamma,
2157-
scale_base,
2158-
exponent_offset,
2159-
block_tables_ptr,
2160-
page_size,
2161-
block_tables_b_stride,
2162-
cache_seqpos_
2163-
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
2164-
actual_batch_size_ptr,
2165-
rope_scaling,
2166-
old_context_len,
2167-
scaling_factor,
2168-
lo_freq_factor,
2169-
hi_freq_factor);
2170-
C10_CUDA_KERNEL_LAUNCH_CHECK();
2156+
FBGEMM_LAUNCH_KERNEL(
2157+
(rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>),
2158+
blocks,
2159+
threads,
2160+
0,
2161+
at::cuda::getCurrentCUDAStream(),
2162+
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2163+
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2164+
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2165+
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
2166+
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
2167+
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
2168+
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
2169+
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
2170+
theta,
2171+
gamma,
2172+
scale_base,
2173+
exponent_offset,
2174+
block_tables_ptr,
2175+
page_size,
2176+
block_tables_b_stride,
2177+
cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
2178+
actual_batch_size_ptr,
2179+
rope_scaling,
2180+
old_context_len,
2181+
scaling_factor,
2182+
lo_freq_factor,
2183+
hi_freq_factor,
2184+
false,
2185+
true);
2186+
21712187
} else {
21722188
auto num_groups_ = num_groups ? num_groups.value() : 1;
21732189
auto seqpos_ =

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ struct KernelLauncher {
443443
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
444444
([&] { \
445445
constexpr auto context = SOURCE_CONTEXT_CURRENT(KERNEL); \
446-
decltype(KERNEL)& kernel = KERNEL; \
446+
const auto& kernel = KERNEL; \
447447
\
448448
return fbgemm_gpu::utils:: \
449449
KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>(context) \

0 commit comments

Comments (0)