Skip to content

Commit f3102e1

Browse files
committed
Stops redundant dQ scaling
Prevents applying the softmax factor twice in the backward preprocessing so downstream gradients stay correctly scaled.
1 parent fefb7a9 commit f3102e1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -279,8 +279,8 @@ inline __device__ void convert_dQ(
279 279         for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
280 280         tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
281 281     }
282     -   #pragma unroll
283     -   for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
    282 +   // #pragma unroll
    283 +   // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
284 284     // Convert acc_dq from fp32 to fp16
285 285     Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
286 286     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom, AtomNum), MMA_N, MMA_N)

0 commit comments

Comments (0)