
Commit 9c40aaf

[FA]: Optimize cache hit for N_CTX <= 512
1 parent 6db3b52 commit 9c40aaf

File tree

1 file changed (+6, -0 lines)

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 6 additions & 0 deletions
@@ -78,6 +78,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     start_m = tl.program_id(2)
     off_z = tl.program_id(0)
     off_h = tl.program_id(1)
+    if N_CTX <= 512:
+        start_m = tl.program_id(0)
+        off_z = tl.program_id(2)
     qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

     # block pointers
@@ -176,6 +179,9 @@ def forward(q, k, v, causal, sm_scale):
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
     grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
+    n_ctx = q.shape[2]
+    if n_ctx <= 512:
+        grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

     if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
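
For context: the original grid is (Z, H, ceil(N_CTX / BLOCK_M)), with program_id(0) indexing the batch and program_id(2) the sequence block. For short sequences (N_CTX <= 512) the commit swaps axes 0 and 2 on both the launch and the kernel side, so program_id(0) indexes sequence blocks instead. The likely rationale is that consecutively dispatched workgroups then stay within the same (batch, head) and reread the same K/V data while it is still hot in cache. Below is a minimal sketch, not from the commit, illustrating this under the assumption that workgroups are dispatched in linear order with grid axis 0 varying fastest; the real dispatch order is hardware- and driver-dependent, and the sizes Z, H, M_BLOCKS are hypothetical.

# Sketch: how the grid-axis swap changes which program IDs are neighbours
# in dispatch order (assumption: axis 0 varies fastest).

Z, H, M_BLOCKS = 2, 2, 4  # hypothetical batch size, head count, sequence blocks

def dispatch(grid):
    """Enumerate (pid0, pid1, pid2) with grid axis 0 varying fastest."""
    g0, g1, g2 = grid
    for lin in range(g0 * g1 * g2):
        yield lin % g0, (lin // g0) % g1, lin // (g0 * g1)

# Original layout: grid = (Z, H, M_BLOCKS). Consecutive workgroups change
# off_z first, so neighbours read K/V belonging to different batches.
for pid0, pid1, pid2 in list(dispatch((Z, H, M_BLOCKS)))[:4]:
    off_z, off_h, start_m = pid0, pid1, pid2
    print("old:", off_z, off_h, start_m)

# Swapped layout: grid = (M_BLOCKS, H, Z). Consecutive workgroups change
# start_m first, so neighbours share the same (off_z, off_h) K/V tensors,
# which is the cache-reuse effect the commit message refers to.
for pid0, pid1, pid2 in list(dispatch((M_BLOCKS, H, Z)))[:4]:
    start_m, off_h, off_z = pid0, pid1, pid2
    print("new:", off_z, off_h, start_m)

The 512 cutoff presumably bounds how many sequence blocks (and hence how much K/V data per head) must fit in cache for the reuse to pay off; the commit itself does not state the derivation.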
