Commit d226164

Use QKVO threads-per-row for reduction partitioning
Aligns reduction configuration with the QKVO-specific per-row thread count to keep template and divisor consistent. Fixes a mismatch that could mis-partition threads, improving correctness and consistency in backward preprocessing.
1 parent bbfbbc3 commit d226164

1 file changed: +2 -2 lines changed

csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h

Lines changed: 2 additions & 2 deletions
@@ -148,9 +148,9 @@ inline __device__ void compute_dot_do_o(
         tdOcdO, tdOpdO,
         binfo.actual_seqlen_q - m_block * kBlockM
     );
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+    dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
         tdOrdO, tdOrO, dP_sum,
-        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
     );
     if (Clear_dQaccum) {
         // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
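
The change keeps the template argument of `dot_do_o` and the divisor in its last argument tied to the same constant. As a rough host-side sketch of why that matters, assuming the last argument is the number of rows covered per pass (`kNThreads` divided by the threads sharing one row): the `ExampleTraits` struct, its numeric values, and both helper functions below are hypothetical and are not taken from the repository.

```cpp
#include <cassert>

// Hypothetical traits; the numeric values are illustrative assumptions only.
struct ExampleTraits {
    static constexpr int kNThreads = 128;
    static constexpr int kGmemThreadsPerRow = 8;       // assumed generic layout
    static constexpr int kGmemThreadsPerRowQKVO = 16;  // assumed QKVO layout
};

// Rows handled in one pass when each row is shared by `threads_per_row` threads.
constexpr int rows_per_pass(int n_threads, int threads_per_row) {
    return n_threads / threads_per_row;
}

// True when the reduction width (the template argument) and the row count
// derived from the divisor account for every thread exactly once.
constexpr bool partition_is_consistent(int n_threads, int reduce_width, int divisor_width) {
    return reduce_width * rows_per_pass(n_threads, divisor_width) == n_threads;
}

// Same constant on both sides: 16 threads per row * 8 rows = 128 threads.
static_assert(partition_is_consistent(ExampleTraits::kNThreads,
                                      ExampleTraits::kGmemThreadsPerRowQKVO,
                                      ExampleTraits::kGmemThreadsPerRowQKVO),
              "matching widths partition the block exactly");

// Mixed constants: 16 threads per row * 16 rows = 256, which is not 128.
static_assert(!partition_is_consistent(ExampleTraits::kNThreads,
                                       ExampleTraits::kGmemThreadsPerRowQKVO,
                                       ExampleTraits::kGmemThreadsPerRow),
              "mixed widths do not partition the block");
```

Under these assumptions, any pairing of different per-row counts fails the check, which is the kind of mis-partitioning the commit message refers to. The same reasoning suggests the constant should also match the one used to build the `tdOrdO` and `tdOrO` partitions, though that is not visible in this hunk.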
