Commit 349c92f

fix_fa3_bug (#11038)
1 parent 3737315 commit 349c92f

1 file changed: +1 -4 lines changed


paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 4 deletions
@@ -1553,10 +1553,7 @@ def backward(ctx, dout):
         else:
             assert False, f"invalid {FA_VERSION=}"
 
-        if FA_VERSION == 2:
-            assert not recompute_fa3
-            assert attn_out is not None and softmax_lse is not None
-        if FA_VERSION == 3 and not recompute_fa3:
+        if (FA_VERSION == 3 and not recompute_fa3) or FA_VERSION == 2:
             assert attn_out is not None and softmax_lse is not None
 
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
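
The commit message does not say what the bug was, so the following is only a reading of the diff. The two versions of the check appear to differ in exactly one situation: `FA_VERSION == 2` with `recompute_fa3` enabled, where the old code tripped `assert not recompute_fa3` and the new code only requires the saved tensors to be present. The sketch below lifts both checks into standalone functions and enumerates every combination to confirm that. The function names (`old_guard`, `new_guard`, `passes`, `differing_cases`) and the use of plain arguments in place of the module-level `FA_VERSION` are assumptions made for illustration and are not part of `modeling.py`.

```python
def old_guard(fa_version, recompute_fa3, attn_out, softmax_lse):
    """Checks as they stood before the commit (hypothetical standalone form)."""
    if fa_version == 2:
        assert not recompute_fa3
        assert attn_out is not None and softmax_lse is not None
    if fa_version == 3 and not recompute_fa3:
        assert attn_out is not None and softmax_lse is not None


def new_guard(fa_version, recompute_fa3, attn_out, softmax_lse):
    """Check as it stands after the commit (hypothetical standalone form)."""
    if (fa_version == 3 and not recompute_fa3) or fa_version == 2:
        assert attn_out is not None and softmax_lse is not None


def passes(guard, *args):
    """Return True if the guard raises no AssertionError for these inputs."""
    try:
        guard(*args)
        return True
    except AssertionError:
        return False


def differing_cases():
    """List (fa_version, recompute_fa3, have_saved) where old and new disagree."""
    cases = []
    for fa_version in (2, 3):
        for recompute_fa3 in (False, True):
            for have_saved in (False, True):
                # A placeholder object stands in for the saved tensors.
                saved = object() if have_saved else None
                args = (fa_version, recompute_fa3, saved, saved)
                if passes(old_guard, *args) != passes(new_guard, *args):
                    cases.append((fa_version, recompute_fa3, have_saved))
    return cases


if __name__ == "__main__":
    print(differing_cases())  # [(2, True, True)]
```

Under this reading, the flag `recompute_fa3` no longer has any effect on the FlashAttention 2 path: the backward pass simply expects `attn_out` and `softmax_lse` to have been saved.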
