Skip to content

Commit 99778f4

Browse files
authored
[FA]: squeeze Z and H into one axis to align with XeTLA (#2618)
Squeeze Z and H into a single axis, the same way XeTLA does. This change yields about a 3% benefit for shapes with N_CTX = 512.
1 parent 49a52a2 commit 99778f4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,11 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
7878
start_m = tl.program_id(2)
7979
off_z = tl.program_id(0)
8080
off_h = tl.program_id(1)
81+
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
8182
if N_CTX <= 512:
8283
start_m = tl.program_id(0)
8384
off_z = tl.program_id(2)
84-
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
85+
qvk_offset = off_z.to(tl.int64) * stride_qh
8586

8687
# block pointers
8788
Q_block_ptr = tl.make_block_ptr(
@@ -181,7 +182,7 @@ def forward(q, k, v, causal, sm_scale):
181182
grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
182183
n_ctx = q.shape[2]
183184
if n_ctx <= 512:
184-
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
185+
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), 1, q.shape[0] * q.shape[1])
185186
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
186187

187188
if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':

0 commit comments

Comments
 (0)