Commit 71562ec

[FlashAttn Backward] Change the config of flex attn bwd kernel (#5152)
Change num_warps and num_stages to 16 and 3, which gives better compile-time and runtime performance. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 9290e9a · commit 71562ec

File tree

1 file changed: +1 −1 lines

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ def backward(ctx, do):
         dv = torch.empty_like(v)
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         PRE_BLOCK = 128
-        NUM_WARPS, NUM_STAGES = 4, 5
+        NUM_WARPS, NUM_STAGES = 16, 3
         BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
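For context, Triton JIT kernels accept these values as launch meta-parameters, e.g. `kernel[grid](..., num_warps=NUM_WARPS, num_stages=NUM_STAGES)`. A minimal, dependency-free sketch of the new configuration (the `launch_config` helper is hypothetical, for illustration only; the real launch lives in `flash_attention_benchmark.py`):

```python
# Values introduced by this commit (previously 4 and 5).
NUM_WARPS, NUM_STAGES = 16, 3

def launch_config(num_warps: int, num_stages: int) -> dict:
    """Hypothetical helper: collect the meta-parameters a Triton kernel
    launch would receive as keyword arguments."""
    return {"num_warps": num_warps, "num_stages": num_stages}

config = launch_config(NUM_WARPS, NUM_STAGES)
print(config)  # {'num_warps': 16, 'num_stages': 3}
```

Roughly speaking, `num_warps` sets the number of warps per program instance and `num_stages` the software-pipelining depth; dropping from 5 to 3 stages can offset the extra register and shared-memory pressure that the wider 16-warp configuration introduces.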
