Skip to content

Commit 6acc10c

Browse files
[FA] Add tl.assume to flash-attention.py
With these assumptions, the attn_fwd kernel can use buffer ops. This doesn't give any noticeable boost, but it may be helpful on newer architectures.
2 parents 5c18bec + 0f629d8 commit 6acc10c

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

python/perf-kernels/flash-attention.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,27 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
443443
ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr,
444444
INT8: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):
445445

446+
tl.assume(stride_qz >= 0)
447+
tl.assume(stride_qh >= 0)
448+
tl.assume(stride_qm >= 0)
449+
tl.assume(stride_qk >= 0)
450+
tl.assume(stride_kz >= 0)
451+
tl.assume(stride_kh >= 0)
452+
tl.assume(stride_kn >= 0)
453+
tl.assume(stride_kk >= 0)
454+
tl.assume(stride_bz >= 0)
455+
tl.assume(stride_bh >= 0)
456+
tl.assume(stride_bm >= 0)
457+
tl.assume(stride_bn >= 0)
458+
tl.assume(stride_vz >= 0)
459+
tl.assume(stride_vh >= 0)
460+
tl.assume(stride_vk >= 0)
461+
tl.assume(stride_vn >= 0)
462+
tl.assume(stride_oz >= 0)
463+
tl.assume(stride_oh >= 0)
464+
tl.assume(stride_om >= 0)
465+
tl.assume(stride_on >= 0)
466+
446467
if PERSISTENT: # if persistent, kernel loops over multiple tiles
447468
NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched
448469
num_tiles_per_head = tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) # the number of work units (tiles) of a single head

0 commit comments

Comments
 (0)