
Commit 14f2c1c

【Fix Bug】fix startend_row_indices bug (#2565)
1 parent: c2a93f6

File tree: 1 file changed (+1, −3)


paddleformers/transformers/llama/fusion_ops.py

Lines changed: 1 addition & 3 deletions
@@ -248,16 +248,14 @@ def fusion_flash_attention(
     else:
         if attn_mask_startend_row_indices is not None:
             assert alibi is None, "flashmask_attention or flash_attention_with_sparse_mask not support alibi"
-            if len(attn_mask_startend_row_indices.shape) == 2:
-                attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

             if hasattr(F, "flashmask_attention"):
                 attn_output = no_recompute(
                     F.flashmask_attention,
                     query_states,
                     key_states,
                     value_states,
-                    startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
+                    startend_row_indices=attn_mask_startend_row_indices,
                     causal=True,
                     enable=skip_recompute,
                 )
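The fix removes the two reshaping steps (the rank-2 paddle.unsqueeze(..., axis=1) and the trailing .unsqueeze(-1)) and forwards attn_mask_startend_row_indices to F.flashmask_attention unchanged, so the caller is responsible for supplying an indices tensor that is already in the layout the kernel expects. For context, below is a minimal sketch of the resulting call path; it assumes a Paddle build that exposes paddle.nn.functional.flashmask_attention, and it folds away the no_recompute / skip_recompute recomputation machinery used in the real fusion_flash_attention. The function name sparse_masked_attention and the fallback branch are illustrative, not part of the repository code.

    # Minimal sketch of the fixed call path, not the real fusion_ops.py code.
    # Assumes a Paddle build where paddle.nn.functional.flashmask_attention exists
    # and the caller supplies attn_mask_startend_row_indices already shaped as the
    # kernel expects; the no_recompute/skip_recompute wrapper is omitted.
    import paddle.nn.functional as F


    def sparse_masked_attention(query_states, key_states, value_states,
                                attn_mask_startend_row_indices, alibi=None):
        # FlashMask-style sparse masking is incompatible with alibi biases.
        assert alibi is None, "flashmask_attention does not support alibi"

        if hasattr(F, "flashmask_attention"):
            # The indices tensor is forwarded unchanged; the removed unsqueeze
            # calls added extra dimensions to a tensor the caller had already
            # shaped correctly.
            return F.flashmask_attention(
                query_states,
                key_states,
                value_states,
                startend_row_indices=attn_mask_startend_row_indices,
                causal=True,
            )

        # Fallback for builds without FlashMask (a simplification; the real code
        # dispatches to other fused attention paths instead).
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states, is_causal=True
        )

Passing the tensor through untouched avoids double-expanding inputs that are already in the kernel's expected layout, which is what the removed unsqueeze calls caused.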
