Skip to content

Commit 34ee120

Browse files
authored
Remove the if-else block in causal loop (#779)
* Remove the if-else block in causal loop * Remove unnecessary ( ) * fix format
1 parent 44313cf commit 34ee120

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

python/perf-kernels/flash-attention.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,29 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
261261
# We start from end of seqlen_k so only the first iteration would need
262262
# to be checked for padding if it is not a multiple of block_n
263263
# TODO: This can be optimized to only be true for the padded block.
264+
mask = tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1)
264265
if MASK_STEPS:
265266
# If this is the last block / iteration, we want to
266267
# mask if the sequence length is not a multiple of block size
267268
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
268269
# last step might get wasted but that is okay. Check if this masking works for
269270
# that case.
270-
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
271-
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
272-
size_n = start_n + OFFS_N[None, :]
273-
mask = size_n < boundary_m[:, None]
274-
qk = tl.where(mask, qk, float("-inf"))
271+
272+
# remove the old if condition
273+
# if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
274+
# Though this will unconditionally compute mask_partial at runtime,
275+
# the causal for loop does not have the if-else block any more, which
276+
# helps instruction scheduling and register pressure.
277+
bound_cond = (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0)
278+
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
279+
size_n = start_n + OFFS_N[None, :]
280+
mask_partial = size_n < boundary_m[:, None]
281+
mask = tl.where(bound_cond, mask_partial, mask)
275282
if IS_CAUSAL:
276283
causal_boundary = start_n + offs_n_causal
277284
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
278-
qk = tl.where(causal_mask, qk, float("-inf"))
285+
mask = mask and causal_mask
286+
qk = tl.where(mask, qk, float("-inf"))
279287
# -- compute qk ----
280288
if INT8_GEMM:
281289
qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE)
@@ -370,7 +378,7 @@ def is_hip():
370378

371379
def is_cdna():
372380
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
373-
'gfx90a', 'gfx908')
381+
'gfx950', 'gfx90a', 'gfx908')
374382

375383

376384
def is_rdna():

0 commit comments

Comments
 (0)