From 9c40aaff38d9d60cab2c3fa35cf080026cfbd869 Mon Sep 17 00:00:00 2001
From: "Wang, Quintin"
Date: Thu, 31 Oct 2024 04:27:56 +0000
Subject: [PATCH] [FA]: Optimize cache hit for N_CTX <= 512

---
 .../flash_attention_fwd_benchmark.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 83cca419ec..9409452aa5 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -78,6 +78,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
     start_m = tl.program_id(2)
     off_z = tl.program_id(0)
     off_h = tl.program_id(1)
+    if N_CTX <= 512:
+        start_m = tl.program_id(0)
+        off_z = tl.program_id(2)
     qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
 
     # block pointers
@@ -176,6 +179,9 @@ def forward(q, k, v, causal, sm_scale):
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
     grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
+    n_ctx = q.shape[2]
+    if n_ctx <= 512:
+        grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
 
     if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
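
A minimal standalone sketch of the grid selection this patch introduces; the helper name make_grid, the example shapes, and the printed launch tuple are illustrative assumptions, not code from the repository. The intent, as the subject line suggests, is that for short sequences consecutive work-groups along axis 0 belong to the same (batch, head) pair and so revisit the same K/V data, which should improve cache hits.

import torch
import triton


def make_grid(q: torch.Tensor):
    # Default layout: (batch, heads, ceil(N_CTX / BLOCK_M)), i.e. program_id(0)
    # walks the batch dimension. For N_CTX <= 512 the axes are swapped so that
    # program_id(0) walks the query blocks of one (batch, head) pair; the kernel
    # side of the patch matches this by reading start_m from program_id(0) and
    # off_z from program_id(2).
    n_ctx = q.shape[2]
    if n_ctx <= 512:
        return lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
    return lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))


# Example (assumed shapes): q of shape (batch=4, heads=16, N_CTX=512, head_dim=64)
# with BLOCK_M=128 selects the swapped layout and yields (4, 16, 4).
q = torch.empty(4, 16, 512, 64)
print(make_grid(q)({'BLOCK_M': 128}))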