
Commit 11f2081

[None][feat] Cherry-pick #10335: Use XQA JIT impl by default and mitigate perf loss with sliding window (#10954)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
1 parent 453b7d1 · commit 11f2081

3 files changed (+79 lines, -42 lines)

cpp/kernels/xqa/gen_cubins.py

Lines changed: 6 additions & 4 deletions
@@ -89,7 +89,8 @@
 
 #include "tensorrt_llm/common/config.h"
 
-TRTLLM_NAMESPACE_BEGIN
+namespace tensorrt_llm
+{
 namespace kernels
 {
 // clang-format off
@@ -98,7 +99,7 @@
 cpp_file_suffex_text = R"""
 // clang-format on
 } // namespace kernels
-TRTLLM_NAMESPACE_END
+}
 """
 
 cubin_meta_info_struct_prefix_text = R"""
@@ -438,8 +439,9 @@ def generate_header_file_contents(
         CompileMacroOption('HEAD_ELEMS', 'd', [128]),
         CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
         CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
-        CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
-                           [0, 64, 128]),  # 0 denotes contiguous kv cache.
+        CompileMacroOption(
+            'TOKENS_PER_PAGE', 'pagedKV',
+            [0, 32, 64, 128]),  # 0 denotes contiguous kv cache.
         CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
         CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
     ]]
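
With this change, the generated cubin headers spell the namespaces out instead of using the TRTLLM_NAMESPACE_BEGIN/END macros, and the precompiled kernel matrix gains a 32-tokens-per-page paged-KV variant. A sketch of the header skeleton the prefix/suffix strings above now produce (the middle part, the generated cubin byte arrays, is elided):

#include "tensorrt_llm/common/config.h"

namespace tensorrt_llm
{
namespace kernels
{
// clang-format off
// ... generated cubin byte arrays and metadata ...
// clang-format on
} // namespace kernels
}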

cpp/kernels/xqa/mha.cu

Lines changed: 71 additions & 18 deletions
@@ -465,33 +465,24 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
 #if SPEC_DEC
 #define MMAS_N_PER_MASK 2
 
-__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
-    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-    ,
-    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
-#endif
-)
+__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
+    uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
+    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
 {
     uint32_t const idxInQuad = laneId() % 4;
     uint32_t const idxQuad = laneId() / 4;
     // Packed mask is aligned with 32 bits (2 uint16_t).
     uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
     uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
     constexpr uint64_t fullMask = ~uint64_t{0};
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
     Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
     Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
     bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
     assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
     int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
     uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
     bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
-#else
-    constexpr bool ctaNeedBegMask = false;
-    bool const ctaNeedSpecDecMask = true;
-    int32_t const tok0NbMaskOut = -2147483648;
-#endif
     bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
 
     if (!needMask)
@@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
 }
 #endif
 
+__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
+    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
+{
+    uint32_t const idxInQuad = laneId() % 4;
+    uint32_t const idxQuad = laneId() / 4;
+    // Packed mask is aligned with 32 bits (2 uint16_t).
+    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
+    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
+#pragma unroll
+    for (uint32_t m = 0; m < acc.rows; m++)
+    {
+#pragma unroll
+        for (uint32_t i = 0; i < InstAcc::rows; i++)
+        {
+            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
+#pragma unroll
+            for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
+            {
+                uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
+                uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
+                uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t packedMask = 0u;
+                uint32_t const maskPosStart = (maskPos0 / 16) * 16;
+                reinterpret_cast<uint16_t*>(&packedMask)[0]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
+                reinterpret_cast<uint16_t*>(&packedMask)[1]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
+#pragma unroll
+                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
+                {
+#pragma unroll
+                    for (uint32_t j = 0; j < InstAcc::cols; j++)
+                    {
+                        uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
+                        uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
+                        // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
+                        // qSeqLen - nbValidCols)];
+                        bool const maskFlag = col + actualQSeqLen < nbValidCols
+                            ? true
+                            : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
+                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
+                    }
+                }
+            }
+        }
+    }
+}
+
+#endif
+
 __device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
 {
     QuadRegRowMax rowMax = rowMaxHint;
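
As the comment in the new function notes, the packed spec-dec mask stores one bit per (row, column) pair in rows of uint16_t words aligned in 32-bit groups; the kernel loads the two words covering a column range into packedMask and tests bits by their offset from maskPosStart. A minimal host-side sketch of the per-bit semantics, under the assumption of a plain row-major bit mask (testMaskBit and the sample data are illustrative stand-ins, not TensorRT-LLM code):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Test one mask bit at (row, col); each row holds wordsPerRow packed
// uint16_t words with 16 mask bits each, least-significant bit first.
bool testMaskBit(uint16_t const* mask, uint32_t wordsPerRow, uint32_t row, uint32_t col)
{
    uint16_t const word = mask[row * wordsPerRow + col / 16];
    return (word >> (col % 16)) & 1u;
}

int main()
{
    // One row of 32 mask bits: columns 0..19 allowed, 20..31 masked out.
    uint16_t const mask[2] = {0xFFFFu, 0x000Fu};
    assert(testMaskBit(mask, 2, 0, 19));
    assert(!testMaskBit(mask, 2, 0, 20));
    std::printf("packed-mask lookup ok\n");
    return 0;
}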
@@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
     uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
     int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
     uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
-
+    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
 #elif SLIDING_WINDOW
     bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
     assert(!SPEC_DEC || !rtIsReallySliding);
@@ -1673,7 +1719,8 @@ CUBIN_EXPORT __global__
 
     uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
 #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
-    uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
+    uint32_t const nbSeqItersWithoutMask
+        = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #elif SPEC_DEC
     uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #endif
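
The effect of this hunk: when the window is not yet really sliding, the kernel now falls back to the plain SPEC_DEC formula for how many leading CTA tiles can skip masking, instead of unconditionally using nbSkipLeadingTiles. A standalone sketch with made-up inputs (ctaTileX, cacheSeqLen, actualQSeqLen, and nbSkipLeadingTiles are hypothetical values, not taken from the kernel):

#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t const ctaTileX = 64;          // hypothetical CTA tile width
    uint32_t const cacheSeqLen = 300;      // hypothetical kv-cache length
    uint32_t const actualQSeqLen = 8;      // hypothetical number of draft tokens
    uint32_t const nbSkipLeadingTiles = 2; // hypothetical tiles fully outside the window

    bool const cases[2] = {true, false};
    for (bool const rtIsReallySliding : cases)
    {
        // Mirrors the patched expression above.
        uint32_t const nbSeqItersWithoutMask
            = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTileX;
        std::printf("rtIsReallySliding=%d -> nbSeqItersWithoutMask=%u\n", int(rtIsReallySliding),
            nbSeqItersWithoutMask);
    }
    return 0;
}

With these numbers it prints 2 for the sliding case and 4 ((300 - 8) / 64) otherwise.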
@@ -1960,12 +2007,18 @@ CUBIN_EXPORT __global__
         if (seqIter >= nbSeqItersWithoutMask)
         {
             uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
-            applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-                ,
-                tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
+            if (rtIsReallySliding)
+            {
+                applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
+                    actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
+            }
+            else
 #endif
-            );
+            {
+                applyMaskFromInput(
+                    warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
+            }
         }
 #else
         bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
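
The call site keeps the preprocessor guard around the sliding-window branch and adds a run-time rtIsReallySliding check, so non-sliding requests take the cheaper generic mask path and the sliding-window arguments never exist in builds without SLIDING_WINDOW. A self-contained sketch of this if/else-across-#endif shape (the two stub functions are placeholders, not the real kernels):

#include <cstdio>

#define SLIDING_WINDOW 1
#define IS_SPEC_DEC_TREE 0

#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
void applyMaskSlidingAndSpecDec() { std::printf("sliding-window spec-dec mask\n"); }
#endif

void applyMaskGeneric() { std::printf("generic spec-dec mask\n"); }

int main()
{
    bool const rtIsReallySliding = false; // run-time condition, as in the kernel
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
    if (rtIsReallySliding)
    {
        applyMaskSlidingAndSpecDec();
    }
    else
#endif
    {
        applyMaskGeneric();
    }
    return 0;
}

When the guard is disabled, the else disappears with it and the braced generic call stands alone as a plain block.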

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp

Lines changed: 2 additions & 20 deletions
@@ -84,26 +84,8 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
     }
     else
     {
-        if (xqaParams.multi_query_tokens)
-        {
-            // Some multi_query kernels are not ported to JIT yet.
-            auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
-            // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
-            // now.
-            bool const supportedByHopperXqa
-                = (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
-            bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
-                && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
-            bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));
-
-            return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
-                                                                                         : mPrecompiledImpl.get();
-        }
-        else
-        {
-            // regular decoding kernels uses JIT by default
-            return mJITImpl.get();
-        }
+        // uses JIT by default
+        return mJITImpl.get();
     }
 }
 