@@ -465,33 +465,24 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
465465#if SPEC_DEC
466466#define MMAS_N_PER_MASK 2
467467
468- __device__ inline void applyMaskFromInput (Warp const & warp, WarpAcc& acc, MaskType const * mask, uint32_t rowOffset,
469- uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
470468#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
471- ,
472- int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
473- #endif
474- )
469+ __device__ inline void applyMaskFromInputSlidingAndSpecDec (Warp const & warp, WarpAcc& acc, MaskType const * mask,
470+ uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
471+ int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
475472{
476473 uint32_t const idxInQuad = laneId () % 4 ;
477474 uint32_t const idxQuad = laneId () / 4 ;
478475 // Packed mask is aligned with 32 bits (2 uint16_t).
479476 uint32_t const nbPackedMasksPerRow = divUp (qSeqLen, 32u ) * 2u ;
480477 uint16_t const * uint16Mask = reinterpret_cast <uint16_t const *>(mask);
481478 constexpr uint64_t fullMask = ~uint64_t {0 };
482- #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
483479 Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x };
484480 Range const maxMaskOutRange = {0 , mha::max (0 , tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1 )};
485481 bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end ;
486482 assert (ctaNeedBegMask == overlap (tileRange, maxMaskOutRange));
487483 int32_t const tok0NbMaskOut = int32_t (tok0WinBeg) - int32_t (warpTileTokenBeg);
488484 uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
489485 bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
490- #else
491- constexpr bool ctaNeedBegMask = false ;
492- bool const ctaNeedSpecDecMask = true ;
493- int32_t const tok0NbMaskOut = -2147483648 ;
494- #endif
495486 bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
496487
497488 if (!needMask)
@@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
559550}
560551#endif
561552
// Applies the packed input mask to the calling lane's elements of one warp-tile accumulator.
//
// An accumulator element at tile column `col` is left untouched only when both hold:
//   * col < nbValidCols, and
//   * either col + actualQSeqLen < nbValidCols (the column lies before the last actualQSeqLen
//     valid columns, so no mask bit is consulted), or bit (col + actualQSeqLen - nbValidCols)
//     of the selected token row is set in the packed mask.
// Every other element is overwritten with safeInitRowMax.
//
// Parameters:
//   warp          - not referenced in this function body.
//   acc           - accumulator tile, modified in place.
//   mask          - packed mask; read here as uint16_t words, divUp(qSeqLen, 32) * 2 words per token row.
//   rowOffset     - added to the in-tile row index before it is divided by headGrpSize.
//   nbValidCols   - number of valid columns in this tile; columns at or beyond it are always masked out.
//   qSeqLen       - only used to derive the number of packed mask words per token row.
//   actualQSeqLen - number of trailing valid columns covered by mask bits; also clamps the token row
//                   and the mask bit positions to actualQSeqLen - 1.
//   headGrpSize   - divisor that maps an accumulator row index to a mask token row.
//
// NOTE(review): the laneId() % 4 and laneId() / 4 split presumably mirrors the MMA accumulator
// fragment layout - confirm against the InstAcc definition.
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
{
    uint32_t const quadRow = laneId() / 4;
    uint32_t const laneInQuad = laneId() % 4;
    uint16_t const* const maskWords = reinterpret_cast<uint16_t const*>(mask);
    // Each mask row is padded to a multiple of 32 bits, i.e. an even number of uint16_t words.
    uint32_t const wordsPerRow = divUp(qSeqLen, 32u) * 2u;
    uint32_t const lastMaskPos = actualQSeqLen - 1;

#pragma unroll
    for (uint32_t m = 0; m < acc.rows; ++m)
    {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; ++i)
        {
            // Token row of the mask for this accumulator row, clamped to the last real query token.
            uint32_t const tokenRow = min((rowOffset + instM * m + quadRow + i * 8) / headGrpSize, lastMaskPos);
            uint32_t const rowBase = tokenRow * wordsPerRow;

#pragma unroll
            for (uint32_t grp = 0; grp < acc.cols / MMAS_N_PER_MASK; ++grp)
            {
                // First and last tile column this lane touches inside the current group of
                // MMAS_N_PER_MASK instruction tiles.
                uint32_t const colFirst = instN * grp * MMAS_N_PER_MASK + InstAcc::cols * laneInQuad;
                uint32_t const colLast = colFirst + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;

                // Mask bit positions of those two columns: 0 when the column precedes the masked
                // span, otherwise clamped to the last mask position.
                uint32_t posFirst = 0u;
                if (!(colFirst + actualQSeqLen < nbValidCols))
                {
                    posFirst = min(colFirst + actualQSeqLen - nbValidCols, lastMaskPos);
                }
                uint32_t posLast = 0u;
                if (!(colLast + actualQSeqLen < nbValidCols))
                {
                    posLast = min(colLast + actualQSeqLen - nbValidCols, lastMaskPos);
                }

                // Gather the 16-bit words holding both positions into one 32-bit value; bit 0 of
                // that value corresponds to mask position bitBase.
                uint32_t packedBits = 0u;
                uint16_t* const halves = reinterpret_cast<uint16_t*>(&packedBits);
                halves[0] = maskWords[rowBase + (posFirst / 16)];
                halves[1] = maskWords[rowBase + (posLast / 16)];
                uint32_t const bitBase = (posFirst / 16) * 16;

#pragma unroll
                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; ++nj)
                {
                    uint32_t const n = (grp * MMAS_N_PER_MASK + nj);
#pragma unroll
                    for (uint32_t j = 0; j < InstAcc::cols; ++j)
                    {
                        uint32_t const col = instN * n + InstAcc::cols * laneInQuad + j;
                        // The mask bit is only looked up when the column is inside the masked span.
                        bool const allowed = (col + actualQSeqLen < nbValidCols)
                            || ((packedBits & (1u << ((col + actualQSeqLen - nbValidCols) - bitBase))) != 0u);
                        if (!(allowed && col < nbValidCols))
                        {
                            acc(m, n)(i, j) = safeInitRowMax;
                        }
                    }
                }
            }
        }
    }
}
605+
606+ #endif
607+
562608__device__ inline QuadRegRowMax warpTileOnlineSoftmax (Warp const & warp, QuadRegRowMax const & rowMaxHint, WarpAcc& acc)
563609{
564610 QuadRegRowMax rowMax = rowMaxHint;
@@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
16551701 uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
16561702 int32_t const tok0WinBeg = int32_t (tok0SeqLen) - int32_t (slidingWinSize);
16571703 uint32_t const nbTotalSkipTokens = mha::max (0 , tok0WinBeg);
1658-
1704+ bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
16591705#elif SLIDING_WINDOW
16601706 bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
16611707 assert (!SPEC_DEC || !rtIsReallySliding);
@@ -1673,7 +1719,8 @@ CUBIN_EXPORT __global__
16731719
16741720 uint32_t const nbSeqIters = useKVCache ? divUp (cacheSeqLen, ctaTile.x ) : 0 ;
16751721#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1676- uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
1722+ uint32_t const nbSeqItersWithoutMask
1723+ = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
16771724#elif SPEC_DEC
16781725 uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
16791726#endif
@@ -1960,12 +2007,18 @@ CUBIN_EXPORT __global__
19602007 if (seqIter >= nbSeqItersWithoutMask)
19612008 {
19622009 uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U );
1963- applyMaskFromInput (warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
19642010#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
1965- ,
1966- tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
2011+ if (rtIsReallySliding)
2012+ {
2013+ applyMaskFromInputSlidingAndSpecDec (warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
2014+ actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
2015+ }
2016+ else
19672017#endif
1968- );
2018+ {
2019+ applyMaskFromInput (
2020+ warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
2021+ }
19692022 }
19702023#else
19712024 bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
0 commit comments