Commit 3c289e2

[performance_optim] reduce frequency of declaring attention_mask in Ascend NPU flash attention (#38278)
1 parent f5d45d8 commit 3c289e2

File tree

1 file changed: +10 −4 lines changed


src/transformers/integrations/npu_flash_attention.py

Lines changed: 10 additions & 4 deletions
@@ -37,6 +37,8 @@
         "or 3 (down-right aligned causal mask)."
     )
 
+ATTN_MASK_NPU = None
+
 
 def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
@@ -171,7 +173,9 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -181,7 +185,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             sparse_mode=SPARSE_MODE,
         )[0]
 
@@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func(
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",

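The change is a one-time lazy initialization: the 2048 × 2048 upper-triangular causal mask depends only on the target device, not on the inputs, so it can be built on the first call and reused from the module-level ATTN_MASK_NPU global instead of being reallocated on every attention call. Below is a minimal standalone sketch of the same pattern, outside the transformers codebase; the helper name get_causal_mask and the _CAUSAL_MASK global are illustrative, not names from the patched file.

import torch

_CAUSAL_MASK = None  # module-level cache, populated on first use


def get_causal_mask(device: torch.device) -> torch.Tensor:
    """Return a shared 2048x2048 upper-triangular bool mask.

    The mask is built once, on the first call, and reused afterwards,
    mirroring the ATTN_MASK_NPU caching introduced by this commit.
    """
    global _CAUSAL_MASK
    if _CAUSAL_MASK is None:
        # True strictly above the diagonal marks positions to mask out.
        _CAUSAL_MASK = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return _CAUSAL_MASK

One consequence of the global cache, as written in the diff, is that the mask stays on whichever device the first call used; the patched code assumes subsequent calls run on that same NPU device.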