1 file changed: +1 −3
paddleformers/transformers/llama

```diff
@@ -248,16 +248,14 @@ def fusion_flash_attention(
     else:
         if attn_mask_startend_row_indices is not None:
             assert alibi is None, "flashmask_attention or flash_attention_with_sparse_mask not support alibi"
-            if len(attn_mask_startend_row_indices.shape) == 2:
-                attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
 
             if hasattr(F, "flashmask_attention"):
                 attn_output = no_recompute(
                     F.flashmask_attention,
                     query_states,
                     key_states,
                     value_states,
-                    startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
+                    startend_row_indices=attn_mask_startend_row_indices,
                     causal=True,
                     enable=skip_recompute,
                 )
```
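In short, the function no longer reshapes `attn_mask_startend_row_indices` on the caller's behalf: both the 2-D-to-3-D unsqueeze and the trailing `.unsqueeze(-1)` are removed, so the tensor is forwarded to `F.flashmask_attention` unchanged and callers must supply the final shape themselves. A minimal caller-side sketch of what that implies, assuming the expected layout is a 4-D `[batch, num_heads, seq_len, 1]` int32 tensor (the shapes and variable names here are illustrative, not taken from this diff):

```python
import paddle

# Hypothetical caller-side preparation. After this change, any reshaping must
# happen before fusion_flash_attention is called; the function passes
# attn_mask_startend_row_indices through to F.flashmask_attention as-is.
batch_size, seq_len = 2, 128

# A 2-D [batch, seq_len] indices tensor, as a pre-change caller might pass it
# (each entry marks where attention for that column ends; value is illustrative).
row_indices = paddle.full([batch_size, seq_len], seq_len, dtype="int32")

# Reproduce the removed in-function reshapes on the caller side:
# axis 1 adds the head dimension, axis -1 the trailing indices dimension.
row_indices = paddle.unsqueeze(row_indices, axis=1)   # -> [batch, 1, seq_len]
row_indices = paddle.unsqueeze(row_indices, axis=-1)  # -> [batch, 1, seq_len, 1]

print(row_indices.shape)  # [2, 1, 128, 1]
```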