Skip to content

Commit a148a3a

Browse files
committed
Refactors copy function into specialized variants
Splits the generic copy_MN function into four specialized functions:

- copy_MN: basic tensor copying with tiled copy operations
- copy_mask: masked copying operations
- copy_mask_with_or_reduce: copying with OR reduction and block activity tracking
- copy_bias: bias-specific copying with element-wise assignment

Removes the Bool_to_Element template parameter and related conditional logic, simplifying the codebase by creating purpose-specific functions instead of a single overloaded function with multiple behaviors.
1 parent 510aaf5 commit a148a3a

File tree

1 file changed

+126
-15
lines changed

1 file changed

+126
-15
lines changed

csrc/flash_dmattn/src/utils.h

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -540,13 +540,132 @@ __forceinline__ __device__ void copy(
540540
////////////////////////////////////////////////////////////////////////////////////////////////////
541541

542542
template <
    bool Is_even_MN=true, bool Clear_OOB_MN=false,
    typename TiledCopy,
    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_MN(
    TiledCopy tiled_copy,
    Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
    Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
    const int max_M=0
) {
    // Copies S -> D fragment-by-fragment with `tiled_copy`, predicated in both
    // the M and N directions for tiles that may be partially out of bounds:
    //   - identity_MN carries the global (row, col) coordinate of each fragment,
    //     so the row bound is checked as get<0>(identity_MN(0, m, 0)) < max_M;
    //   - predicate_N(n) reports whether column slab n is in bounds;
    //   - Is_even_MN=true compiles away every guard (tile known to be full);
    //   - Clear_OOB_MN=true zero-fills destination fragments that are skipped,
    //     otherwise out-of-bounds fragments are left untouched.
    // S and D are rank-3 (MMA, MMA_M, MMA_N) fragments of matching shape.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_N

    #pragma unroll
    for (int m = 0; m < size<1>(D); ++m) {
        // Short-circuit keeps identity_MN unread when Is_even_MN is set.
        const bool row_in_bounds = Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M;
        if (row_in_bounds) {
            #pragma unroll
            for (int n = 0; n < size<2>(D); ++n) {
                if (Is_even_MN || predicate_N(n)) {
                    cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
                } else if (Clear_OOB_MN) {
                    cute::clear(D(_, m, n));
                }
            }
        } else if (Clear_OOB_MN) {
            // Whole fragment row is out of bounds: clear it in one shot.
            cute::clear(D(_, m, _));
        }
    }
}
576+
577+
////////////////////////////////////////////////////////////////////////////////////////////////////
578+
579+
template <
    bool Is_even_MN=true, bool Clear_OOB_MN=false,
    typename TiledCopy,
    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_mask(
    TiledCopy tiled_copy,
    Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
    Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
    const int max_M=0
) {
    // Mask-tile variant of copy_MN: copies S -> D with `tiled_copy`, skipping
    // fragments whose global row coordinate (from identity_MN) falls at or past
    // max_M, or whose column slab fails predicate_N.  Skipped fragments are
    // zero-filled when Clear_OOB_MN, left untouched otherwise.  Is_even_MN=true
    // removes all predication at compile time.
    // NOTE(review): currently identical in behavior to copy_MN; kept as a
    // separate entry point so mask-specific handling can diverge later.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_N

    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // Guard clause: entire fragment row lies outside the valid region.
        // (Short-circuit keeps identity_MN unread when Is_even_MN is set.)
        if (!Is_even_MN && !(get<0>(identity_MN(0, m, 0)) < max_M)) {
            if (Clear_OOB_MN) { cute::clear(D(_, m, _)); }
            continue;
        }
        #pragma unroll
        for (int n = 0; n < size<2>(S); ++n) {
            // Guard clause: column slab n is out of bounds.
            if (!Is_even_MN && !predicate_N(n)) {
                if (Clear_OOB_MN) { cute::clear(D(_, m, n)); }
                continue;
            }
            cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
        }
    }
}
613+
614+
////////////////////////////////////////////////////////////////////////////////////////////////////
615+
616+
template <
    bool Is_even_MN=true, bool Clear_OOB_MN=false, typename To_type=void,
    typename TiledCopy,
    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_mask_with_or_reduce(
    TiledCopy tiled_copy,
    Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
    bool &block_active,
    Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
    const int max_M=0
) {
    // Copies a mask tile S -> D element-by-element, casting each element to
    // To_type, while OR-reducing the source values: on return, block_active is
    // true iff ANY thread in the block saw a non-zero mask element in bounds.
    // Row bounds come from identity_MN vs. max_M, column bounds from
    // predicate_N; skipped fragments are zero-filled when Clear_OOB_MN.
    // NOTE(review): tiled_copy is accepted for signature symmetry with copy_MN
    // but the copy here is deliberately element-wise — presumably because a
    // vectorized copy would over-read when a slab is partially out of bounds.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});                          // (MMA, MMA_M, MMA_N)
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_N

    // Per-thread accumulator; folded across the block at the end.
    bool found_active = false;
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) {
            #pragma unroll
            for (int n = 0; n < size<2>(S); ++n) {
                if (Is_even_MN || predicate_N(n)) {
                    #pragma unroll
                    for (int i = 0; i < size<0>(S); ++i) {
                        auto const mask_val = S(i, m, n);
                        D(i, m, n) = static_cast<To_type>(mask_val);
                        found_active |= mask_val;
                    }
                } else if (Clear_OOB_MN) {
                    cute::clear(D(_, m, n));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }

    // Block-wide OR-reduction.  __syncthreads_or must be reached by every
    // thread in the block, hence it sits outside all predicated flow above.
    block_active = __syncthreads_or(found_active);
}
658+
659+
////////////////////////////////////////////////////////////////////////////////////////////////////
660+
661+
template <
662+
bool Is_even_MN=true, bool Clear_OOB_MN=false,
663+
typename TiledCopy,
664+
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
665+
typename Engine2, typename Layout2, typename Engine3, typename Layout3
666+
>
667+
__forceinline__ __device__ void copy_bias(
668+
TiledCopy tiled_copy,
550669
Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
551670
Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
552671
const int max_M=0
@@ -563,18 +682,10 @@ __forceinline__ __device__ void copy_MN(
563682
#pragma unroll
564683
for (int n = 0; n < size<2>(S); ++n) {
565684
if (Is_even_MN || predicate_N(n)) {
566-
if constexpr (Bool_to_Element) {
567-
#pragma unroll
568-
for (int i = 0; i < size<0>(S); ++i) {
569-
D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
570-
}
571-
} else {
572-
// Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n)
573-
// cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
574-
#pragma unroll
575-
for (int i = 0; i < size<0>(S); ++i) {
576-
D(i, m, n) = S(i, m, n);
577-
}
685+
// cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
686+
#pragma unroll
687+
for (int i = 0; i < size<0>(S); ++i) {
688+
D(i, m, n) = S(i, m, n);
578689
}
579690
} else if (Clear_OOB_MN) {
580691
cute::clear(D(_, m, n));

0 commit comments

Comments
 (0)