
Commit 77c9e80

Refines bias/mask prep in flash bwd
Pre-scales K once, right after the async-copy synchronization, so later matmul steps reuse the already-scaled values and the wait latency is hidden outside the inner loop. Unifies mask and bias hydration before the streaming loop, keeping the accumulators coherent, and drops the now-redundant dQ gradient scaling.
1 parent f3102e1 commit 77c9e80
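
The arithmetic behind both the pre-scaling and the dropped dQ scaling is that the softmax scale distributes over the matmul: (s·K)ᵀq = s·(Kᵀq), and likewise dQ = dS·(s·K) = s·(dS·K). Below is a minimal host-side sketch of that equivalence, with scalar dot products standing in for the kernel's tiled matmuls; the head dimension of 64 and all values are made up for illustration, and only params.scale_softmax corresponds to anything in the diff.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
    const float scale = 1.0f / std::sqrt(64.0f);  // stands in for params.scale_softmax
    const std::vector<float> K = {0.5f, -1.25f, 2.0f};
    const std::vector<float> q = {1.0f, 0.25f, -0.75f};

    // Old scheme: raw dot product first, scale the result afterwards.
    float scaled_after = 0.0f;
    for (std::size_t i = 0; i < K.size(); ++i) { scaled_after += q[i] * K[i]; }
    scaled_after *= scale;

    // New scheme: fold the scale into K once, then reuse the scaled K.
    float scaled_before = 0.0f;
    for (std::size_t i = 0; i < K.size(); ++i) { scaled_before += q[i] * (K[i] * scale); }

    // (scale * K) . q == scale * (K . q), so both schemes agree.
    assert(std::fabs(scaled_after - scaled_before) < 1e-6f);
    return 0;
}

Once K carries the scale, the dS @ K product that produces dQ carries it too, which is why the third hunk below can comment out the explicit acc_dq(i) *= params.scale_softmax pass.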


csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 31 additions & 56 deletions
@@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }
 
+    cute::cp_async_wait<0>();
+    __syncthreads();
+
+    // Scale K once before the streaming loop over Q
+    #pragma unroll
+    for (int k = 0; k < size(tKsK); ++k) {
+        tKsK(k) = static_cast<Element>(tKsK(k) * params.scale_softmax);
+    }
+
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        cute::cp_async_wait<0>();
-        __syncthreads();
-
         Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (MMA=4, MMA_M, MMA_K)
+        cute::cp_async_wait<0>();
+        __syncthreads();
 
         if (any_active) {
-            clear(acc_s);
+            if constexpr (Has_bias) {
+                // Copy bias from smem to registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
 
+        if (any_active) {
             Tensor dP_sum = make_fragment_like(lse);
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
@@ -686,71 +706,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
             }
 
-            if constexpr (Has_mask && Has_bias) {
-                // Copy mask and bias from smem to registers
-                Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-                Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
-                cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
-
-                // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
-                // actual_seqlen_k, because acc_s would be some finite value for those indices.
-                // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
-                // so the result would still be correct.
-                // However, it's possible that the values in acc_s are so large that they overflow
-                // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
-                // So we need to mask out the elements beyond actual_seqlen_k.
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, mask, bias, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (Has_mask && !Has_bias) {
+            if constexpr (Has_mask) {
                 // Copy mask from smem to registers
                 Tensor tSrMask = make_tensor<Element>(shape(acc_s));
                 Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
                 cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
 
                 // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
                 Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, mask, /*bias=*/nullptr, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (!Has_mask && Has_bias) {
-                // Copy bias from smem to registers
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
 
-                // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, /*mask=*/nullptr, bias, params.scale_softmax,
+                FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
+                    scores, mask,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
                     binfo.actual_seqlen_q,
                     AtomLayoutMS * 16
                 );
             } else {
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+                FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
+                    scores, /*mask=*/nullptr,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
@@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
             }
         } else {
-            #pragma unroll
-            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
+            // #pragma unroll
+            // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
             // Convert acc_dq from fp32 to fp16
             Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
             Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom, AtomNum), MMA_M, MMA_K)
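
A side note on the synchronization move in the first hunk: the cp_async_wait<0>() / __syncthreads() pair now runs once before the Q-streaming loop, so K can be scaled in shared memory as soon as it lands, and inside the loop the wait is deferred until after the accumulator fragments are allocated, giving the in-flight copies some independent work to overlap. A hedged CUDA sketch of that wait-placement pattern follows; it uses the CUDA pipeline primitives rather than cute's wrappers, and the kernel, its fixed 256-thread block, and every name in it are illustrative rather than the file's actual code.

#include <cuda_pipeline.h>

// Illustrative kernel (assumes a 256-thread block and gK with >= 256 floats):
// issue the global->shared copy asynchronously, do independent register work,
// and only then wait, mirroring the hunk's deferred cp_async_wait placement.
__global__ void prescale_k_sketch(const float* __restrict__ gK,
                                  float* __restrict__ out,
                                  int n, float scale_softmax) {
    __shared__ float sK[256];
    const int t = threadIdx.x;

    // Start the copy without blocking this thread.
    __pipeline_memcpy_async(&sK[t], &gK[t], sizeof(float));
    __pipeline_commit();

    // Independent work proceeds while the copy is in flight
    // (standing in for the partition_fragment_C allocations).
    float acc = 0.0f;

    // Wait for the copy, then make shared memory visible block-wide.
    __pipeline_wait_prior(0);
    __syncthreads();

    // Scale K once in shared memory; every later read reuses the scaled value.
    sK[t] *= scale_softmax;
    __syncthreads();

    for (int i = 0; i < n; ++i) { acc += sK[(t + i) % 256]; }
    out[t] = acc;
}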
