Enhance bias gradient accumulation in backward pass #193
Changes from 3 commits
@@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
+    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
+    if constexpr (Has_bias) {
+        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
+    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},

@@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     if constexpr (Has_bias) {
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_dBias,
-            tBiassBias, tdBiasgdBias,
-            tBiascBias, tBiaspBias,
-            binfo.actual_seqlen_q - m_block * kBlockM
-        );
+        if (!params.accum_dbias) {
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                binfo.actual_seqlen_q - m_block * kBlockM
+            );
+        } else {
+            #pragma unroll
+            for (int m = 0; m < size<1>(tBiassBias); ++m) {
+                if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+                    #pragma unroll
+                    for (int n = 0; n < size<2>(tBiassBias); ++n) {
+                        if (Is_even_MN || tBiaspBias(n)) {
+                            #pragma unroll
+                            for (int i = 0; i < size<0>(tBiassBias); ++i) {
+                                const auto coord = tBiascBias(i, m, n);
+                                const int row = get<0>(coord);
+                                const int col = get<1>(coord);
+                                if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
+                                    atomicAdd(
+                                        gdBias_accum_ptr + row * params.dbias_row_stride + col,
+                                        static_cast<ElementAccum>(tBiassBias(i, m, n))
+                                    );
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
     }

     // if (cute::thread0()) { print(tPrP); }

@@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
+        if (params.accum_dbias) {
+            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
+        }
         if (any_active_next) {
            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
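
The pattern in the second hunk is, in essence, "accumulate partial bias gradients in float via atomicAdd, convert to the storage dtype once at the end." The sketch below is a minimal, self-contained illustration of that idea and is not the PR's kernel: the shapes, the broadcast layout, and all names (`accum_dbias_broadcast`, `cast_dbias`) are invented for the example. It assumes the bias gradient is broadcast over batch and seqlen_q, so many tiles land on the same dbias element and plain stores would race.

```cuda
// Standalone sketch of accumulate-in-fp32 with atomicAdd, then a single cast pass.
// Not the flash-dmattn kernel; names and shapes are illustrative only.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void accum_dbias_broadcast(const __half *dS,    // [batch, seqlen_q, seqlen_k]
                                      float *dbias_accum,  // [seqlen_k], zero-initialized
                                      int batch, int seqlen_q, int seqlen_k) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch * seqlen_q * seqlen_k;
    if (idx < total) {
        // Every (batch, query-row) pair maps onto the same broadcast dbias row,
        // so contributions must be added atomically.
        int col = idx % seqlen_k;
        atomicAdd(&dbias_accum[col], __half2float(dS[idx]));
    }
}

__global__ void cast_dbias(const float *dbias_accum, __half *dbias, int seqlen_k) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col < seqlen_k) dbias[col] = __float2half(dbias_accum[col]);
}

int main() {
    const int batch = 2, seqlen_q = 4, seqlen_k = 8, total = batch * seqlen_q * seqlen_k;
    __half *dS, *dbias; float *accum;
    cudaMallocManaged(&dS, total * sizeof(__half));
    cudaMallocManaged(&accum, seqlen_k * sizeof(float));
    cudaMallocManaged(&dbias, seqlen_k * sizeof(__half));
    for (int i = 0; i < total; ++i) dS[i] = __float2half(1.0f);
    cudaMemset(accum, 0, seqlen_k * sizeof(float));
    accum_dbias_broadcast<<<(total + 255) / 256, 256>>>(dS, accum, batch, seqlen_q, seqlen_k);
    cast_dbias<<<(seqlen_k + 255) / 256, 256>>>(accum, dbias, seqlen_k);
    cudaDeviceSynchronize();
    printf("dbias[0] = %f (expected %d)\n", __half2float(dbias[0]), batch * seqlen_q);
    cudaFree(dS); cudaFree(accum); cudaFree(dbias);
    return 0;
}
```

Accumulating in fp32 and converting once avoids both the precision loss of repeated half-precision additions and any reliance on half-precision atomics.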
Comment on lines +863 to +881:
The new condition only allocates an expanded buffer when batch_size_bias == 1 or seqlen_q_bias == 1 (broadcast), but drops the previous handling for general mismatches (batch_size_bias != batch_size or seqlen_q_bias != seqlen_q where the size is not 1). This can cause the kernel to write into a tensor with incompatible shape/strides when dbias dimensions mismatch but are not broadcastable (e.g., batch_size_bias=2 vs batch_size=4, seqlen_q_bias=64 vs seqlen_q=128). Restore the general mismatch allocation while keeping the float accumulation path for seqlen_q_bias == 1. For example, keep the outer mismatch check as (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) and specialize the inner allocation to use float only when seqlen_q_bias == 1.