Commit 15d2334

enable autotune for all attn
1 parent f213106 commit 15d2334

1 file changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 16 additions & 34 deletions
@@ -156,7 +156,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
     for BM in [128, 256] \
     for BN in [32, 64] \
-    for s in [3, 4] \
+    for s in [2, 3, 4] \
     for w in [8, 16, 32] \
 ]

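Note: the only change in this hunk is the pipeline-stage sweep, which grows from [3, 4] to [2, 3, 4]. A quick sanity check of what that does to the autotuner's search space (a throwaway sketch; the values are copied straight from the comprehension above):

    from itertools import product

    # Configs per autotune key: BLOCK_M x BLOCK_N x num_stages x num_warps
    before = len(list(product([128, 256], [32, 64], [3, 4], [8, 16, 32])))    # 2*2*2*3 = 24
    after = len(list(product([128, 256], [32, 64], [2, 3, 4], [8, 16, 32])))  # 2*2*3*3 = 36
    print(before, after)  # 24 36

So the first tuning run per key now benchmarks 36 candidate configs instead of 24.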
@@ -170,43 +170,25 @@ def forward(q, k, v, causal, sm_scale):
     assert Lq == Lk and Lk == Lv
     assert Lk in {16, 32, 64, 128}
     o = torch.empty_like(q, dtype=torch.float32)
-    BLOCK_M = 128
-    BLOCK_N = 64 if Lk <= 64 else 32
-    num_stages = 3
-    num_warps = 8 if Lq == 64 else 16
+    #BLOCK_M = 128
+    #BLOCK_N = 64 if Lk <= 64 else 32
+    #num_stages = 3
+    #num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
     grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

-    if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
-        # default pipeline
-        tune_attn_fwd[grid](
-            q, k, v, sm_scale, M, o, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], #
-            N_CTX=q.shape[2], #
-            BLOCK_DMODEL=Lk, #
-            STAGE=stage, #
-        )
-    else:
-        _attn_fwd[grid](
-            q, k, v, sm_scale, M, o, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], #
-            N_CTX=q.shape[2], #
-            BLOCK_M=BLOCK_M, #
-            BLOCK_N=BLOCK_N, #
-            BLOCK_DMODEL=Lk, #
-            STAGE=stage, #
-            num_warps=num_warps, #
-            num_stages=num_stages #
-        )
+    tune_attn_fwd[grid](
+        q, k, v, sm_scale, M, o, #
+        q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+        k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+        v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+        o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+        q.shape[0], q.shape[1], #
+        N_CTX=q.shape[2], #
+        BLOCK_DMODEL=Lk, #
+        STAGE=stage, #
+    )
     return o

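With the TRITON_INTEL_ADVANCED_PATH branch removed, every call goes through the autotuned tune_attn_fwd, so BLOCK_M, BLOCK_N, num_warps, and num_stages come from the autotuner rather than the hand-picked constants that are now commented out. For context, a minimal sketch of how such a kernel is typically declared, assuming tune_attn_fwd shares _attn_fwd's body; the decorator arguments, key list, and parameter names below are illustrative assumptions, not taken from this commit:

    import triton
    import triton.language as tl

    configs = [
        triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w)
        for BM in [128, 256]
        for BN in [32, 64]
        for s in [2, 3, 4]
        for w in [8, 16, 32]
    ]

    # triton.autotune benchmarks every config on first launch and caches the
    # winner per distinct value of the `key` arguments (key list assumed here).
    @triton.autotune(configs=configs, key=['N_CTX', 'BLOCK_DMODEL'])
    @triton.jit
    def tune_attn_fwd(Q, K, V, sm_scale, M, Out,  #
                      stride_qz, stride_qh, stride_qm, stride_qk,  #
                      stride_kz, stride_kh, stride_kn, stride_kk,  #
                      stride_vz, stride_vh, stride_vk, stride_vn,  #
                      stride_oz, stride_oh, stride_om, stride_on,  #
                      Z, H, N_CTX,  #
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
                      BLOCK_DMODEL: tl.constexpr, STAGE: tl.constexpr):
        # Same body as _attn_fwd; BLOCK_M/BLOCK_N arrive from the winning
        # triton.Config instead of the call site.
        ...

Because the grid lambda reads args['BLOCK_M'], the launch grid is recomputed for whichever block size the autotuner ends up selecting.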