23 changes: 15 additions & 8 deletions csrc/flash_dmattn/flash_api.cpp
@@ -190,6 +190,7 @@ void set_params_dgrad(
     const float softcap,
     bool has_mask,
     bool has_bias,
+    bool accum_dbias,
     bool deterministic,
     const bool unpadded_lse
 ) {
@@ -245,6 +246,8 @@ void set_params_dgrad(
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
 
+    params.accum_dbias = accum_dbias;
+
     params.deterministic = deterministic;
 }

@@ -977,12 +980,13 @@ mha_bwd(
         ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
         : dv;
     dbias_expanded = has_bias
-        ? (
-            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)    // MQA / GQA or dbias has different batch size or seqlen_q
-            ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
-            : dbias
-        )
+        ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)    // MQA / GQA or dbias has different batch size or seqlen_q
+            ? (seqlen_q_bias == 1)
+                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+                : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+            : dbias
         : torch::empty({0}, opts);
+    bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1;
Copilot AI (Oct 16, 2025):

The new condition only allocates an expanded buffer when batch_size_bias == 1 or seqlen_q_bias == 1 (broadcast), but drops the previous handling for general mismatches (batch_size_bias != batch_size or seqlen_q_bias != seqlen_q where the size is not 1). This can cause the kernel to write into a tensor with incompatible shape/strides when dbias dimensions mismatch but are not broadcastable (e.g., batch_size_bias=2 vs batch_size=4, seqlen_q_bias=64 vs seqlen_q=128). Restore the general mismatch allocation while keeping the float accumulation path for seqlen_q_bias == 1. For example, keep the outer mismatch check as (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) and specialize the inner allocation to use float only when seqlen_q_bias == 1.
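A minimal sketch of that suggested fix, reusing the hunk's variable names; this is an illustration of the reviewer's wording, not the committed change:

    // Hypothetical sketch of the reviewer's suggestion, not committed code:
    // keep the general mismatch check from the old condition, and reserve the
    // fp32 buffer for the broadcast seqlen_q_bias == 1 case that accumulates
    // across M-tiles.
    dbias_expanded = has_bias
        ? (
            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)
            ? (
                (seqlen_q_bias == 1)
                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
                : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
            )
            : dbias
        )
        : torch::empty({0}, opts);

The outer check again covers any non-broadcastable shape mismatch with a zero-initialized buffer, while the fp32 allocation is kept only for the single bias row that many M-tiles accumulate into atomically.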

     Flash_bwd_params params;

@@ -1009,6 +1013,7 @@
         softcap,
         has_mask,
         has_bias,
+        accum_dbias,
         deterministic,
         /*unpadded_lse*/false
     );
@@ -1036,9 +1041,10 @@
         if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
             at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         } else {
-            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
-            if (seqlen_q_bias == 1) {
-                dbias_expanded = at::sum(dbias_expanded, {2}, true);
+            if (accum_dbias) {
+                dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
+            } else {
+                dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
             }
             if (batch_size_bias == 1) {
                 dbias_expanded = at::sum(dbias_expanded, {0}, true);
@@ -1238,6 +1244,7 @@ mha_varlen_bwd(
         softcap,
         has_mask,
         has_bias,
+        /*accum_dbias*/false,
         deterministic,
         /*unpadded_lse*/true
     );
2 changes: 2 additions & 0 deletions csrc/flash_dmattn/src/flash.h
@@ -195,6 +195,8 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     bool deterministic;
     index_t dq_accum_split_stride;
+
+    bool accum_dbias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
44 changes: 38 additions & 6 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
+    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
+    if constexpr (Has_bias) {
+        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
+    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     if constexpr (Has_bias) {
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_dBias,
-            tBiassBias, tdBiasgdBias,
-            tBiascBias, tBiaspBias,
-            binfo.actual_seqlen_q - m_block * kBlockM
-        );
+        if (!params.accum_dbias) {
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                binfo.actual_seqlen_q - m_block * kBlockM
+            );
+        } else {
+            #pragma unroll
+            for (int m = 0; m < size<1>(tBiassBias); ++m) {
+                if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+                    #pragma unroll
+                    for (int n = 0; n < size<2>(tBiassBias); ++n) {
+                        if (Is_even_MN || tBiaspBias(n)) {
+                            #pragma unroll
+                            for (int i = 0; i < size<0>(tBiassBias); ++i) {
+                                const auto coord = tBiascBias(i, m, n);
+                                const int row = get<0>(coord);
+                                const int col = get<1>(coord);
+                                if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
+                                    atomicAdd(
+                                        gdBias_accum_ptr + row * params.dbias_row_stride + col,
+                                        static_cast<ElementAccum>(tBiassBias(i, m, n))
+                                    );
+                                }
+                            }
+                        }
Comment on lines +863 to +881
Copilot AI (Oct 16, 2025):

[nitpick] This path performs an atomicAdd per element, which can cause significant contention when seqlen_q is large (many M-tiles accumulate into the same broadcasted bias row). Consider reducing within the threadblock first (e.g., per-(row,col) partial sums in shared memory or warp-level reductions) and issuing a single atomicAdd per (row,col) per block. This typically cuts the number of atomics by a factor of size<0>(tBiassBias) and improves throughput.
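A standalone sketch of that idea under a deliberately simplified layout (a plain rows-by-cols dS tile instead of the CUTE partition above; the kernel name and signature are hypothetical): reduce in registers first, then issue one atomic per destination.

    // Hypothetical illustration of the reviewer's suggestion, not the
    // CUTE-tiled kernel above. For the broadcast case (a single dbias row),
    // each thread reduces its column over the tile's rows in registers and
    // issues one atomicAdd per (column, block) instead of one per element.
    __global__ void accum_dbias_tile(
        const float *dS,   // [rows x cols] tile of dS for this M-block
        float *dbias,      // [cols] broadcast dbias row shared by all M-blocks
        int rows, int cols
    ) {
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (col >= cols) return;
        float partial = 0.f;
        for (int r = 0; r < rows; ++r) {
            partial += dS[r * cols + col];   // in-register reduction over rows
        }
        atomicAdd(&dbias[col], partial);     // single atomic per column per block
    }

Applied to the loop above, this would mean combining the per-thread register fragments that target the same (row, col) before touching global memory, which is where the reviewer's size<0>(tBiassBias) reduction factor comes from.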
+                    }
+                }
+            }
+        }
     }
 
     // if (cute::thread0()) { print(tPrP); }
@@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
+        if (params.accum_dbias) {
+            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
+        }
         if (any_active_next) {
             FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,