
Commit ebb89e6

Add auto-tuner to the flash attention benchmark. (#2299)
The hardcoded configuration for flash attention may not give the best performance with Triton. Add an auto-tuner to the flash attention benchmark.
1 parent 978fe24 commit ebb89e6
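For context, the mechanism this commit adopts is Triton's built-in auto-tuner: `triton.autotune` takes a list of `triton.Config` candidates and, the first time the kernel is launched for a new combination of the `key` argument values, times each candidate and caches the fastest one. Below is a minimal, self-contained sketch of that pattern on a toy vector-add kernel; the kernel, tensor names, and block sizes are illustrative only and are not part of this patch.

import torch
import triton
import triton.language as tl

# Candidate launch configurations; the autotuner times each one and caches
# the fastest per value of the `key` arguments (here: n_elements).
_configs = [
    triton.Config({'BLOCK_SIZE': bs}, num_warps=w)
    for bs in [256, 1024]
    for w in [4, 8]
]

@triton.autotune(configs=_configs, key=['n_elements'])
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    # The grid is a callable here because BLOCK_SIZE is not known until the
    # autotuner has picked a config, so it is read from the meta-arguments.
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )
    _add_kernel[grid](x, y, out, n)  # BLOCK_SIZE / num_warps come from the chosen config
    return out

The same mechanics explain two details of the diff below: `grid` becomes a lambda because `BLOCK_M` is only known after the tuner picks a config, and the `tune_attn_fwd` launch no longer passes `BLOCK_M`, `BLOCK_N`, `num_warps`, or `num_stages`, since the selected `triton.Config` supplies them.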

File tree

1 file changed (+44 / -17 lines)

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 44 additions & 17 deletions
@@ -1,3 +1,4 @@
+import os
 import torch
 import triton
 import triton.language as tl
@@ -151,6 +152,18 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     tl.store(O_block_ptr, acc.to(Out.type.element_ty))
 
 
+configs = [
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
+    for BM in [256] \
+    for BN in [32, 64] \
+    for s in [3] \
+    for w in [32] \
+]
+
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tune_attn_fwd = tuner(_attn_fwd)
+
+
 def forward(q, k, v, causal, sm_scale):
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
@@ -162,23 +175,38 @@ def forward(q, k, v, causal, sm_scale):
     num_stages = 3
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
-    grid = (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], BLOCK_M))
+    grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-    _attn_fwd[grid](
-        q, k, v, sm_scale, M, o, #
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-        q.shape[0], q.shape[1], #
-        N_CTX=q.shape[2], #
-        BLOCK_M=BLOCK_M, #
-        BLOCK_N=BLOCK_N, #
-        BLOCK_DMODEL=Lk, #
-        STAGE=stage, #
-        num_warps=num_warps, #
-        num_stages=num_stages #
-    )
+
+    if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
+        # default pipeline
+        tune_attn_fwd[grid](
+            q, k, v, sm_scale, M, o, #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+            q.shape[0], q.shape[1], #
+            N_CTX=q.shape[2], #
+            BLOCK_DMODEL=Lk, #
+            STAGE=stage, #
+        )
+    else:
+        _attn_fwd[grid](
+            q, k, v, sm_scale, M, o, #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+            q.shape[0], q.shape[1], #
+            N_CTX=q.shape[2], #
+            BLOCK_M=BLOCK_M, #
+            BLOCK_N=BLOCK_N, #
+            BLOCK_DMODEL=Lk, #
+            STAGE=stage, #
+            num_warps=num_warps, #
+            num_stages=num_stages #
+        )
     return o
 
@@ -243,7 +271,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
     elif provider == 'triton':
         # FIXME: remove below if condition when extend attention support for Causal = True done
         # https://github.com/intel/intel-xpu-backend-for-triton/issues/1102
-        import os
         if os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '1' and CAUSAL:
             min_ms, max_ms, mean, cv = (float('inf'), ) * 4
         else:
