@@ -540,13 +540,132 @@ __forceinline__ __device__ void copy(
540540// //////////////////////////////////////////////////////////////////////////////////////////////////
541541
542542template <
543- bool Is_even_MN=true , bool Clear_OOB_MN=false , bool Bool_to_Element= false , typename To_type= void ,
544- // typename TiledCopy,
543+ bool Is_even_MN=true , bool Clear_OOB_MN=false ,
544+ typename TiledCopy,
545545 typename Engine0, typename Layout0, typename Engine1, typename Layout1,
546546 typename Engine2, typename Layout2, typename Engine3, typename Layout3
547547>
548548__forceinline__ __device__ void copy_MN (
549- // TiledCopy tiled_copy,
549+ TiledCopy tiled_copy,
550+ Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
551+ Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
552+ const int max_M=0
553+ ) {
554+ CUTE_STATIC_ASSERT_V (rank (S) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
555+ CUTE_STATIC_ASSERT_V (rank (D) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
556+ CUTE_STATIC_ASSERT_V (size<0 >(S) == size<0 >(D)); // MMA
557+ CUTE_STATIC_ASSERT_V (size<1 >(S) == size<1 >(D)); // MMA_M
558+ CUTE_STATIC_ASSERT_V (size<2 >(S) == size<2 >(D)); // MMA_N
559+
560+ #pragma unroll
561+ for (int m = 0 ; m < size<1 >(S); ++m) {
562+ if (Is_even_MN || get<0 >(identity_MN (0 , m, 0 )) < max_M) {
563+ #pragma unroll
564+ for (int n = 0 ; n < size<2 >(S); ++n) {
565+ if (Is_even_MN || predicate_N (n)) {
566+ cute::copy (tiled_copy, S (_, m, n), D (_, m, n));
567+ } else if (Clear_OOB_MN) {
568+ cute::clear (D (_, m, n));
569+ }
570+ }
571+ } else if (Clear_OOB_MN) {
572+ cute::clear (D (_, m, _));
573+ }
574+ }
575+ }
576+
577+ // //////////////////////////////////////////////////////////////////////////////////////////////////
578+
579+ template <
580+ bool Is_even_MN=true , bool Clear_OOB_MN=false ,
581+ typename TiledCopy,
582+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
583+ typename Engine2, typename Layout2, typename Engine3, typename Layout3
584+ >
585+ __forceinline__ __device__ void copy_mask (
586+ TiledCopy tiled_copy,
587+ Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
588+ Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
589+ const int max_M=0
590+ ) {
591+ CUTE_STATIC_ASSERT_V (rank (S) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
592+ CUTE_STATIC_ASSERT_V (rank (D) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
593+ CUTE_STATIC_ASSERT_V (size<0 >(S) == size<0 >(D)); // MMA
594+ CUTE_STATIC_ASSERT_V (size<1 >(S) == size<1 >(D)); // MMA_M
595+ CUTE_STATIC_ASSERT_V (size<2 >(S) == size<2 >(D)); // MMA_N
596+
597+ #pragma unroll
598+ for (int m = 0 ; m < size<1 >(S); ++m) {
599+ if (Is_even_MN || get<0 >(identity_MN (0 , m, 0 )) < max_M) {
600+ #pragma unroll
601+ for (int n = 0 ; n < size<2 >(S); ++n) {
602+ if (Is_even_MN || predicate_N (n)) {
603+ cute::copy (tiled_copy, S (_, m, n), D (_, m, n));
604+ } else if (Clear_OOB_MN) {
605+ cute::clear (D (_, m, n));
606+ }
607+ }
608+ } else if (Clear_OOB_MN) {
609+ cute::clear (D (_, m, _));
610+ }
611+ }
612+ }
613+
614+ // //////////////////////////////////////////////////////////////////////////////////////////////////
615+
616+ template <
617+ bool Is_even_MN=true , bool Clear_OOB_MN=false , typename To_type=void ,
618+ typename TiledCopy,
619+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
620+ typename Engine2, typename Layout2, typename Engine3, typename Layout3
621+ >
622+ __forceinline__ __device__ void copy_mask_with_or_reduce (
623+ TiledCopy tiled_copy,
624+ Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
625+ bool &block_active,
626+ Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
627+ const int max_M=0
628+ ) {
629+ CUTE_STATIC_ASSERT_V (rank (S) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
630+ CUTE_STATIC_ASSERT_V (rank (D) == Int<3 >{}); // (MMA, MMA_M, MMA_N)
631+ CUTE_STATIC_ASSERT_V (size<0 >(S) == size<0 >(D)); // MMA
632+ CUTE_STATIC_ASSERT_V (size<1 >(S) == size<1 >(D)); // MMA_M
633+ CUTE_STATIC_ASSERT_V (size<2 >(S) == size<2 >(D)); // MMA_N
634+
635+ bool any_active = false ;
636+ #pragma unroll
637+ for (int m = 0 ; m < size<1 >(S); ++m) {
638+ if (Is_even_MN || get<0 >(identity_MN (0 , m, 0 )) < max_M) {
639+ #pragma unroll
640+ for (int n = 0 ; n < size<2 >(S); ++n) {
641+ if (Is_even_MN || predicate_N (n)) {
642+ #pragma unroll
643+ for (int i = 0 ; i < size<0 >(S); ++i) {
644+ any_active |= S (i, m, n);
645+ D (i, m, n) = static_cast <To_type>(S (i, m, n));
646+ }
647+ } else if (Clear_OOB_MN) {
648+ cute::clear (D (_, m, n));
649+ }
650+ }
651+ } else if (Clear_OOB_MN) {
652+ cute::clear (D (_, m, _));
653+ }
654+ }
655+
656+ block_active = __syncthreads_or (any_active);
657+ }
658+
659+ // //////////////////////////////////////////////////////////////////////////////////////////////////
660+
661+ template <
662+ bool Is_even_MN=true , bool Clear_OOB_MN=false ,
663+ typename TiledCopy,
664+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
665+ typename Engine2, typename Layout2, typename Engine3, typename Layout3
666+ >
667+ __forceinline__ __device__ void copy_bias (
668+ TiledCopy tiled_copy,
550669 Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
551670 Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
552671 const int max_M=0
@@ -563,18 +682,10 @@ __forceinline__ __device__ void copy_MN(
563682 #pragma unroll
564683 for (int n = 0 ; n < size<2 >(S); ++n) {
565684 if (Is_even_MN || predicate_N (n)) {
566- if constexpr (Bool_to_Element) {
567- #pragma unroll
568- for (int i = 0 ; i < size<0 >(S); ++i) {
569- D (i, m, n) = static_cast <bool >(S (i, m, n)) ? To_type (1 ) : To_type (0 );
570- }
571- } else {
572- // Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n)
573- // cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
574- #pragma unroll
575- for (int i = 0 ; i < size<0 >(S); ++i) {
576- D (i, m, n) = S (i, m, n);
577- }
685+ // cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
686+ #pragma unroll
687+ for (int i = 0 ; i < size<0 >(S); ++i) {
688+ D (i, m, n) = S (i, m, n);
578689 }
579690 } else if (Clear_OOB_MN) {
580691 cute::clear (D (_, m, n));
0 commit comments