Skip to content

Commit 8a3bb04

Browse files
committed
Fixes out-of-bounds access in copy function
Replaces vectorized copy with element-wise assignment to prevent memory access violations when bounds checking is disabled. Changes predicate handling to use dedicated predicate tensor instead of coordinate-based bounds checking for improved safety. Updates default Clear_OOB_MN to false and removes max_N parameter as bounds checking now relies on predicate tensor.
1 parent c82f7dc commit 8a3bb04

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

csrc/flash_dmattn/src/utils.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,16 @@ __forceinline__ __device__ void copy(
521521
////////////////////////////////////////////////////////////////////////////////////////////////////
522522

523523
template <
524-
bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
525-
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
526-
typename Engine2, typename Layout2
524+
bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void,
525+
// typename TiledCopy,
526+
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
527+
typename Engine2, typename Layout2, typename Engine3, typename Layout3
527528
>
528529
__forceinline__ __device__ void copy_MN(
529-
TiledCopy tiled_copy,
530+
// TiledCopy tiled_copy,
530531
Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
531-
Tensor<Engine2, Layout2> const &identity_MN,
532-
const int max_M=0, const int max_N=0
532+
Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
533+
const int max_M=0
533534
) {
534535
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N)
535536
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N)
@@ -542,14 +543,19 @@ __forceinline__ __device__ void copy_MN(
542543
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) {
543544
#pragma unroll
544545
for (int n = 0; n < size<2>(S); ++n) {
545-
if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
546+
if (Is_even_MN || predicate_N(n)) {
546547
if constexpr (Bool_to_Element) {
547548
#pragma unroll
548549
for (int i = 0; i < size<0>(S); ++i) {
549550
D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
550551
}
551552
} else {
552-
cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
553+
// Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n)
554+
// cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
555+
#pragma unroll
556+
for (int i = 0; i < size<0>(S); ++i) {
557+
D(i, m, n) = S(i, m, n);
558+
}
553559
}
554560
} else if (Clear_OOB_MN) {
555561
cute::clear(D(_, m, n));

0 commit comments

Comments
 (0)