3737 "or 3 (down-right aligned causal mask)."
3838 )
3939
40+ ATTN_MASK_NPU = None
41+
4042
4143def is_npu_fa2_top_left_aligned_causal_mask ():
4244 return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available () else False
@@ -171,7 +173,9 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -181,7 +185,7 @@ def npu_flash_attn_func(
181185 "BSND" ,
182186 keep_prob = keep_prob ,
183187 scale = softmax_scale ,
184- atten_mask = attn_mask_npu ,
188+ atten_mask = ATTN_MASK_NPU ,
185189 sparse_mode = SPARSE_MODE ,
186190 )[0 ]
187191
@@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func(
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",
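
For reference, the change swaps a per-call mask allocation for a module-level cache that is built lazily on the first causal call and reused afterwards. Below is a minimal sketch of that pattern in plain PyTorch; the names `_CACHED_ATTN_MASK` and `get_causal_mask` are illustrative and not part of the patch.

import torch

# Module-level cache: the 2048x2048 upper-triangular boolean mask is created
# once per process instead of on every attention call.
_CACHED_ATTN_MASK = None


def get_causal_mask(device):
    # Lazily build the mask on first use, then return the cached tensor.
    global _CACHED_ATTN_MASK
    if _CACHED_ATTN_MASK is None:
        _CACHED_ATTN_MASK = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return _CACHED_ATTN_MASK

As in the patch itself, the cached tensor keeps the device of the first call, which is fine as long as all subsequent attention calls run on that same device.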