Skip to content

Commit 2c35c89

Browse files
committed
Drops redundant softmax unscale
Removes the unused reverse-scaling parameter (`unscale_softmax`) from the forward-pass parameters, so it can no longer hold a stale value when softcap is toggled.
1 parent 77c9e80 commit 2c35c89

File tree

2 files changed

+0
-3
lines changed

2 files changed

+0
-3
lines changed

csrc/flash_dmattn/flash_api.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,11 @@ void set_params_fprop(
132 132   params.softcap = softmax_scale / softcap;
133 133   params.scale_softmax = softcap;
134 134   params.scale_softmax_log2 = softcap * M_LOG2E;
135     - params.unscale_softmax = 1.0f / softmax_scale;
136 135   } else{
137 136   // Remove potential NaN
138 137   params.softcap = 0.0;
139 138   params.scale_softmax = softmax_scale;
140 139   params.scale_softmax_log2 = softmax_scale * M_LOG2E;
141     - params.unscale_softmax = 1.0f / softmax_scale;
142 140   }
143 141
144 142   params.is_causal = is_causal;

csrc/flash_dmattn/src/flash.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
101 101   // The scaling factors for the kernel.
102 102   float scale_softmax;
103 103   float scale_softmax_log2;
104     - float unscale_softmax;
105 104   float softcap;
106 105
107 106   // array of length b+1 holding starting offset of each sequence.

0 commit comments

Comments
 (0)