Skip to content

Commit 60a7fde

Browse files
committed
Adds boolean-to-element conversion support in copy function
Introduces template parameters to enable converting boolean values to numeric elements during copy operations. Adds conditional logic that converts true values to 1.0f and false values to 0.0f when the Bool_to_Element flag is enabled, allowing for more flexible data type transformations in memory copy routines.
1 parent ab06c18 commit 60a7fde

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

csrc/flash_dmattn/src/utils.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ __forceinline__ __device__ void copy(
521521
////////////////////////////////////////////////////////////////////////////////////////////////////
522522

523523
template <
524-
bool Is_even_MN=true, bool Clear_OOB_MN=true,
524+
bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
525525
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
526526
typename Engine2, typename Layout2
527527
>
@@ -543,7 +543,16 @@ __forceinline__ __device__ void copy_MN(
543543
#pragma unroll
544544
for (int n = 0; n < size<2>(S); ++n) {
545545
if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
546-
cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
546+
if constexpr (Bool_to_Element) {
547+
#pragma unroll
548+
for (int i = 0; i < size<0>(S); ++i) {
549+
D(i, m, n) = static_cast<bool>(S(i, m, n))
550+
? static_cast<To_type>(1.0f)
551+
: static_cast<To_type>(0.0f);
552+
}
553+
} else {
554+
cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
555+
}
547556
} else if (Clear_OOB_MN) {
548557
cute::clear(D(_, m, n));
549558
}

0 commit comments

Comments
 (0)