@@ -101,6 +101,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    bool accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1);

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -159,10 +160,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
-    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
-    if constexpr (Has_bias) {
-        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
-    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -287,8 +284,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
-    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
-    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
@@ -297,6 +292,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     >;
     GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
     auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -346,6 +343,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
+    [[maybe_unused]] auto acc_dbias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
+    [[maybe_unused]] auto acc_dbias_rowcol = make_tensor(acc_dbias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dbias.layout()));

     // Copy Atom retiling
     auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
@@ -641,8 +640,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }

-    clear(acc_dv);
     clear(acc_dk);
+    clear(acc_dv);
+    if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
@@ -806,6 +806,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
                 if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
+                if constexpr (Has_bias) {
+                    if (accum_dbias) {
+                        acc_dbias_rowcol(mi, ni) += scaled_ds;
+                    }
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
@@ -852,36 +857,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         if constexpr (Has_bias) {
             // Write dS to dBias
-            if (!params.accum_dbias) {
+            if (!accum_dbias) {
                 FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
                     gmem_tiled_copy_dBias,
                     tBiassBias, tdBiasgdBias,
                     tBiascBias, tBiaspBias,
                     binfo.actual_seqlen_q - m_block * kBlockM
                 );
-            } else {
-                #pragma unroll
-                for (int m = 0; m < size<1>(tBiassBias); ++m) {
-                    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
-                        #pragma unroll
-                        for (int n = 0; n < size<2>(tBiassBias); ++n) {
-                            if (Is_even_MN || tBiaspBias(n)) {
-                                #pragma unroll
-                                for (int i = 0; i < size<0>(tBiassBias); ++i) {
-                                    const auto coord = tBiascBias(i, m, n);
-                                    const int row = get<0>(coord);
-                                    const int col = get<1>(coord);
-                                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
-                                        atomicAdd(
-                                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
-                                            static_cast<ElementAccum>(tBiassBias(i, m, n))
-                                        );
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
             }
         }

@@ -1023,9 +1005,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
-        if (params.accum_dbias) {
-            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
-        }
         if (any_active_next) {
             FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
@@ -1069,10 +1048,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     // Epilogue

+    if constexpr (Has_bias) {
+        if (accum_dbias) {
+            const int actual_block_n = Is_even_MN ? kBlockN : std::max(0, std::min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN));
+
+            // Convert acc_dbias from fp32 to fp16
+            Tensor tdBiasrdBias = FLASH_NAMESPACE::convert_type<Element>(acc_dbias);
+
+            // Partition sBias to match the accumulator partitioning
+            Tensor tdBiasadBias = smem_thr_copy_Bias.retile_S(tdBiasrdBias);  // ((Atom, AtomNum), MMA_M, MMA_N)
+
+            // We need syncthreads here since we're writing to the same location as sBias.
+            // Without syncthreads, some thread might modify the location of sBias while another thread
+            // is reading it for dQ gemm, leading to a race condition.
+            // If Is_last, there's already a __syncthreads() at the end of the loop.
+            if (!Is_last) { __syncthreads(); }
+
+            cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
+
+            __syncthreads();
+            for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
+                if (col < actual_block_n) {
+                    ElementAccum rowsum = 0.f;
+                    #pragma unroll
+                    for (int row = 0; row < kBlockM; ++row) {
+                        rowsum += static_cast<ElementAccum>(sdS(row, col));
+                    }
+                    sdS(0, col) = static_cast<Element>(rowsum);
+                }
+            }
+            __syncthreads();
+
+            #pragma unroll
+            for (int ni = 0; ni < size(tBiaspBias); ++ni) { tBiaspBias(ni) = ni < actual_block_n; }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/false, /*Clear_OOB_MN=*/false>(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                /*max_M=*/1
+            );
+        }
+    }
+
     #pragma unroll
     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }

-    // Convert acc_dv from fp32 to fp16
+    // Convert acc_dk, acc_dv from fp32 to fp16
     Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
     Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);

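For context, here is a minimal host-side sketch of the semantics the new epilogue implements (the function name and layout assumptions are hypothetical, not part of the kernel): when the bias is broadcast across the query dimension (dbias_row_stride == 0 and actual_seqlen_q > 1, i.e. accum_dbias is true), its gradient is the column-wise sum of dS over all query rows. The diff accumulates this per-tile in acc_dbias, reduces it across rows in shared memory, and writes a single row per column block, replacing the removed per-element atomicAdd path.

// Hypothetical reference, assuming dS is stored row-major as [seqlen_q, seqlen_k].
#include <vector>

std::vector<float> reference_dbias_broadcast_q(const std::vector<float>& dS,
                                               int seqlen_q, int seqlen_k) {
    // dbias has one entry per key position; reduce dS over the query dimension.
    std::vector<float> dbias(seqlen_k, 0.0f);
    for (int row = 0; row < seqlen_q; ++row) {
        for (int col = 0; col < seqlen_k; ++col) {
            dbias[col] += dS[row * seqlen_k + col];
        }
    }
    return dbias;
}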