@@ -261,21 +261,29 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
261261 # We start from end of seqlen_k so only the first iteration would need
262262 # to be checked for padding if it is not a multiple of block_n
263263 # TODO: This can be optimized to only be true for the padded block.
264+ mask = tl .full ([BLOCK_M , BLOCK_N ], True , dtype = tl .int1 )
264265 if MASK_STEPS :
265266 # If this is the last block / iteration, we want to
266267 # mask if the sequence length is not a multiple of block size
267268 # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
268269 # last step might get wasted but that is okay. check if this masking works For
269270 # that case.
270- if (start_n + BLOCK_N == block_max ) and (n_extra_tokens != 0 ):
271- boundary_m = tl .full ([BLOCK_M ], actual_seqlen_k , dtype = tl .int32 )
272- size_n = start_n + OFFS_N [None , :]
273- mask = size_n < boundary_m [:, None ]
274- qk = tl .where (mask , qk , float ("-inf" ))
271+
272+ # remove the old if condition
273+ # if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
274+ # Though this will unconditionally compute mask_partial at runtime,
275+ # the causal for loop does not have the if-else block any more, which
276+ # helps instruction scheduling and register pressure.
277+ bound_cond = (start_n + BLOCK_N == block_max ) and (n_extra_tokens != 0 )
278+ boundary_m = tl .full ([BLOCK_M ], actual_seqlen_k , dtype = tl .int32 )
279+ size_n = start_n + OFFS_N [None , :]
280+ mask_partial = size_n < boundary_m [:, None ]
281+ mask = tl .where (bound_cond , mask_partial , mask )
275282 if IS_CAUSAL :
276283 causal_boundary = start_n + offs_n_causal
277284 causal_mask = OFFS_M [:, None ] >= causal_boundary [None , :]
278- qk = tl .where (causal_mask , qk , float ("-inf" ))
285+ mask = mask and causal_mask
286+ qk = tl .where (mask , qk , float ("-inf" ))
279287 # -- compute qk ----
280288 if INT8_GEMM :
281289 qk += ((((tl .dot (q , k ).to (tl .float32 ) * q_descale )) * k_descale ) * QK_SCALE )
@@ -370,7 +378,7 @@ def is_hip():
370378
371379def is_cdna ():
372380 return is_hip () and triton .runtime .driver .active .get_current_target ().arch in ('gfx940' , 'gfx941' , 'gfx942' ,
373- 'gfx90a' , 'gfx908' )
381+ 'gfx950' , ' gfx90a' , 'gfx908' )
374382
375383
376384def is_rdna ():
0 commit comments