Skip to content

Commit 8426847

Browse files
authored
fix fa3 scale bug (#10961)
1 parent 8cb1712 commit 8426847

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1306,9 +1306,9 @@ def forward(
13061306
)
13071307

13081308
q_head_dim = query_states.shape[-1]
1309-
softmax_scale = softmax_scale * (q_head_dim**0.5)
13101309

13111310
if FA_VERSION == 2:
1311+
softmax_scale = softmax_scale * (q_head_dim**0.5)
13121312
query_states = query_states * softmax_scale
13131313
kv_seq_len = value_states.shape[1]
13141314
v_num_heads = value_states.shape[2]

0 commit comments

Comments
 (0)