Skip to content

Commit 0f42a24

Browse files
authored
[None][feat] Fix attention sink load in xqa (#8836)
Signed-off-by: Qidi Sang <[email protected]>
1 parent 6d6797c commit 0f42a24

File tree

1 file changed: 7 additions and 3 deletions

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,9 +2078,13 @@ __device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec)
     for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++)
     {
         static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
-        ret[i] = reinterpret_cast<
-            Vec<Vec<float, GmmaAccCoreMat::cols>, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
-            gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
+        uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+        uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
+#pragma unroll
+        for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
+        {
+            ret[i][j] = gmemVec[baseOffset + j];
+        }
     }
     return ret;
 }

0 commit comments

Comments
 (0)