Skip to content

Commit 039634d

Browse files
committed
enable autotune for all attn
1 parent f213106 commit 039634d

File tree

1 file changed

+12
-30
lines changed

1 file changed

+12
-30
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
156156
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
157157
for BM in [128, 256] \
158158
for BN in [32, 64] \
159-
for s in [3, 4] \
159+
for s in [2, 3, 4] \
160160
for w in [8, 16, 32] \
161161
]
162162

@@ -178,35 +178,17 @@ def forward(q, k, v, causal, sm_scale):
178178
grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
179179
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
180180

181-
if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
182-
# default pipeline
183-
tune_attn_fwd[grid](
184-
q, k, v, sm_scale, M, o, #
185-
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
186-
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
187-
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
188-
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
189-
q.shape[0], q.shape[1], #
190-
N_CTX=q.shape[2], #
191-
BLOCK_DMODEL=Lk, #
192-
STAGE=stage, #
193-
)
194-
else:
195-
_attn_fwd[grid](
196-
q, k, v, sm_scale, M, o, #
197-
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
198-
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
199-
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
200-
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
201-
q.shape[0], q.shape[1], #
202-
N_CTX=q.shape[2], #
203-
BLOCK_M=BLOCK_M, #
204-
BLOCK_N=BLOCK_N, #
205-
BLOCK_DMODEL=Lk, #
206-
STAGE=stage, #
207-
num_warps=num_warps, #
208-
num_stages=num_stages #
209-
)
181+
tune_attn_fwd[grid](
182+
q, k, v, sm_scale, M, o, #
183+
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
184+
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
185+
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
186+
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
187+
q.shape[0], q.shape[1], #
188+
N_CTX=q.shape[2], #
189+
BLOCK_DMODEL=Lk, #
190+
STAGE=stage, #
191+
)
210192
return o
211193

212194

0 commit comments

Comments
 (0)