Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions cpp/kernels/xqa/gen_cubins.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@

#include "tensorrt_llm/common/config.h"

TRTLLM_NAMESPACE_BEGIN
namespace tensorrt_llm
{
namespace kernels
{
// clang-format off
Expand All @@ -98,7 +99,7 @@
cpp_file_suffex_text = R"""
// clang-format on
} // namespace kernels
TRTLLM_NAMESPACE_END
}
"""

cubin_meta_info_struct_prefix_text = R"""
Expand Down Expand Up @@ -438,8 +439,9 @@ def generate_header_file_contents(
CompileMacroOption('HEAD_ELEMS', 'd', [128]),
CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
[0, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption(
'TOKENS_PER_PAGE', 'pagedKV',
[0, 32, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
]]
Expand Down
89 changes: 71 additions & 18 deletions cpp/kernels/xqa/mha.cu
Original file line number Diff line number Diff line change
Expand Up @@ -465,33 +465,24 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
#if SPEC_DEC
#define MMAS_N_PER_MASK 2

__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
#endif
)
__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
{
uint32_t const idxInQuad = laneId() % 4;
uint32_t const idxQuad = laneId() / 4;
// Packed mask is aligned with 32 bits (2 uint16_t).
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
constexpr uint64_t fullMask = ~uint64_t{0};
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
#else
constexpr bool ctaNeedBegMask = false;
bool const ctaNeedSpecDecMask = true;
int32_t const tok0NbMaskOut = -2147483648;
#endif
bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;

if (!needMask)
Expand Down Expand Up @@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
}
#endif

// Applies a user-supplied packed attention mask (spec-dec style) to one warp's accumulator tile,
// without any sliding-window handling (the sliding-window variant lives in
// applyMaskFromInputSlidingAndSpecDec).
//
// Parameters:
//   warp          - warp handle (unused directly here beyond lane identity; kept for interface symmetry).
//   acc           - warp accumulator tile; elements that fail the mask are overwritten with safeInitRowMax.
//   mask          - packed bitmask, 1 bit per (query token, key position); reinterpreted as uint16_t words.
//   rowOffset     - row offset of this warp's tile within the CTA tile (in accumulator rows).
//   nbValidCols   - number of in-bounds key columns for this tile; columns >= nbValidCols are always masked out.
//   qSeqLen       - padded query sequence length; determines the packed-mask row stride.
//   actualQSeqLen - real (unpadded) query sequence length; used to clamp row/position indices.
//   headGrpSize   - number of head-group rows per query token (rows are divided by this to get the token index).
//
// NOTE(review): assumes acc.cols is divisible by MMAS_N_PER_MASK, and that the column span covered by
// one (maskPos0, maskPos1) pair fits within two consecutive 16-bit mask words — verify against the
// instN/InstAcc::cols constants defined elsewhere in this file.
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
{
    // MMA quad layout: each group of 4 lanes owns consecutive columns; 8 quads cover the rows.
    uint32_t const idxInQuad = laneId() % 4;
    uint32_t const idxQuad = laneId() / 4;
    // Packed mask is aligned with 32 bits (2 uint16_t).
    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
#pragma unroll
    for (uint32_t m = 0; m < acc.rows; m++)
    {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; i++)
        {
            // Map this accumulator row to its query-token index; clamp so padded rows reuse the last real token.
            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
#pragma unroll
            // One packed-mask fetch covers MMAS_N_PER_MASK adjacent MMA column groups.
            for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
            {
                // First/last key columns this lane touches within the fused column group.
                uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
                uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
                // Translate tile-local columns to mask positions. Columns before (nbValidCols - actualQSeqLen)
                // lie ahead of the masked region and map to position 0; positions are clamped to the last
                // real query token.
                uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
                    ? 0u
                    : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
                uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
                    ? 0u
                    : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
                // Assemble a 32-bit window of the mask row: low half holds the word containing maskPos0,
                // high half the word containing maskPos1 (same word when both positions share it).
                uint32_t packedMask = 0u;
                uint32_t const maskPosStart = (maskPos0 / 16) * 16;
                reinterpret_cast<uint16_t*>(&packedMask)[0]
                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
                reinterpret_cast<uint16_t*>(&packedMask)[1]
                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
#pragma unroll
                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
                {
#pragma unroll
                    for (uint32_t j = 0; j < InstAcc::cols; j++)
                    {
                        uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
                        uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
                        // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
                        // qSeqLen - nbValidCols)];
                        // Bit index is relative to maskPosStart so it addresses the 32-bit window built above.
                        bool const maskFlag = col + actualQSeqLen < nbValidCols
                            ? true
                            : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
                        // Masked-out or out-of-bounds elements are reset to safeInitRowMax so the subsequent
                        // row-max/softmax treats them as -inf contributions.
                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
                    }
                }
            }
        }
    }
}

#endif

__device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
{
QuadRegRowMax rowMax = rowMaxHint;
Expand Down Expand Up @@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);

bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
#elif SLIDING_WINDOW
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
assert(!SPEC_DEC || !rtIsReallySliding);
Expand All @@ -1673,7 +1719,8 @@ CUBIN_EXPORT __global__

uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
uint32_t const nbSeqItersWithoutMask
= rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#elif SPEC_DEC
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#endif
Expand Down Expand Up @@ -1960,12 +2007,18 @@ CUBIN_EXPORT __global__
if (seqIter >= nbSeqItersWithoutMask)
{
uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
if (rtIsReallySliding)
{
applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
}
else
#endif
);
{
applyMaskFromInput(
warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
}
}
#else
bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,8 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
}
else
{
if (xqaParams.multi_query_tokens)
{
// Some multi_query kernels are not ported to JIT yet.
auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
// Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
// now.
bool const supportedByHopperXqa
= (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
&& xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));

return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
: mPrecompiledImpl.get();
}
else
{
// regular decoding kernels uses JIT by default
return mJITImpl.get();
}
// uses JIT by default
return mJITImpl.get();
}
}

Expand Down
Loading