Commit b8fc4b9

[FA]:Optimize FlashAttention for N_CTX <= 512 (#2600)
This change can improve the cache behavior for N_CTX=512; we see a 20%+ performance gain from it, but it may be harmful for large N_CTX, so the change is restricted to N_CTX <= 512. [CI data](https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11603077461/job/32309360931) ![image](https://github.com/user-attachments/assets/cf0470f0-c28f-48ba-9888-4fe24bee53a3)
1 parent c5beb57 commit b8fc4b9

File tree

1 file changed (+6, -0 lines)


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 6 additions & 0 deletions
@@ -78,6 +78,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     start_m = tl.program_id(2)
     off_z = tl.program_id(0)
     off_h = tl.program_id(1)
+    if N_CTX <= 512:
+        start_m = tl.program_id(0)
+        off_z = tl.program_id(2)
     qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
 
     # block pointers
@@ -176,6 +179,9 @@ def forward(q, k, v, causal, sm_scale):
         num_warps = 8 if Lq == 64 else 16
         stage = 3 if causal else 1
         grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
+        n_ctx = q.shape[2]
+        if n_ctx <= 512:
+            grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
         M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
 
         if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
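
For context, a minimal sketch of the launch-side remapping, assuming illustrative shapes (`Z`, `H`, `N_CTX`, `BLOCK_M` below are placeholders, not values from the benchmark): for small contexts the query-block index moves to grid axis 0 and the batch index to axis 2, matching the `tl.program_id()` swap inside `_attn_fwd` in the diff above; the presumed effect is better K/V reuse between neighbouring programs, which is the cache behavior the commit message describes.

```python
import triton

# Illustrative shapes for this sketch only (not taken from the benchmark).
Z, H, N_CTX, BLOCK_M = 4, 16, 512, 128

# Launch order before this commit:
#   program_id(0) = batch (off_z), program_id(1) = head (off_h),
#   program_id(2) = query block (start_m).
grid_default = (Z, H, triton.cdiv(N_CTX, BLOCK_M))

# Launch order for N_CTX <= 512 after this commit: the query-block index
# becomes axis 0 and the batch index moves to axis 2.
grid_small_ctx = (triton.cdiv(N_CTX, BLOCK_M), H, Z)

print(grid_default)    # (4, 16, 4)
print(grid_small_ctx)  # (4, 16, 4) -- same program count, different axis meaning
```

The total number of programs is unchanged; only the meaning of the grid axes (and hence the scheduling order) differs, which is why the gain is attributed to cache behavior rather than extra parallelism.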

0 commit comments