Skip to content

Commit 0f629d8

Browse files
committed
Store strides, too.
1 parent fcfa1a5 commit 0f629d8

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

python/perf-kernels/flash-attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
459459
tl.assume(stride_vh >= 0)
460460
tl.assume(stride_vk >= 0)
461461
tl.assume(stride_vn >= 0)
462+
tl.assume(stride_oz >= 0)
463+
tl.assume(stride_oh >= 0)
464+
tl.assume(stride_om >= 0)
465+
tl.assume(stride_on >= 0)
462466

463467
if PERSISTENT: # if persistent, kernel loops over multiple tiles
464468
NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched

0 commit comments

Comments
 (0)