Commit 9f1ebab

Support to run the flex attn microbench on CUDA. (#4806)
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 4d26c20 commit 9f1ebab
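
For reference, the change replaces hardcoded device='xpu' tensor placement with the device reported by Triton's active driver, so the same microbenchmark can run on XPU or CUDA. A minimal sketch of that pattern, assuming a Triton build that exposes triton.runtime.driver.active.get_active_torch_device() (the tensor shape below is illustrative, not one of the benchmark's configurations):

import torch
import triton

# Ask Triton's active driver which torch device it targets (e.g. xpu or cuda)
# instead of hardcoding device='xpu'.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

# Benchmark inputs then follow whichever backend was detected.
q = torch.randn((1, 2, 128, 64), device=DEVICE, dtype=torch.float16)
print(DEVICE, q.device)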

File tree

1 file changed: +12 -7 lines changed

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 12 additions & 7 deletions
@@ -16,6 +16,9 @@
 from torch._inductor.template_heuristics import FlexConfig, FlexDecodeConfig
 
 import triton_kernels_benchmark as benchmark_suit
+import triton
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 # Use TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results.
 # torch._inductor.config.max_autotune_gemm = True
@@ -59,7 +62,7 @@ def get_flex_decode_configs(*args, **kwargs):  # pylint: disable=unused-argument
 
 
 @lru_cache
-def create_block_mask_cached(score_mod, B, H, M, N, device='xpu'):
+def create_block_mask_cached(score_mod, B, H, M, N, device=DEVICE):
     block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
     return block_mask
 
@@ -136,19 +139,20 @@ def causal_mask(_, __, q_idx, kv_idx):
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
     assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
-    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
-    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
-    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=DEVICE)
     torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
 
     if provider == 'torch':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+                                                              device=DEVICE)
 
     elif provider == 'triton':
         kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
@@ -160,7 +164,8 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
         triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
 
         benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+                                                              device=DEVICE)
 
     elif provider == 'onednn':
         # OneDNN only supports MHA.
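
To see the patched call path end to end, here is a small self-contained sketch (not the benchmark harness itself; shapes are illustrative, and PyTorch >= 2.5 is assumed for torch.nn.attention.flex_attention):

import torch
import triton
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def causal_mask(_, __, q_idx, kv_idx):
    # Standard causal mask_mod: a query may only attend to positions at or before it.
    return q_idx >= kv_idx


# Illustrative shapes; the benchmark sweeps its own (Z, H, N_CTX, D_HEAD) configurations.
Z, H, N_CTX, D_HEAD = 1, 2, 128, 64
q = torch.randn((Z, H, N_CTX, D_HEAD), device=DEVICE, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device=DEVICE)
out = flex_attention(q, k, v, block_mask=block_mask, scale=0.125)
print(out.shape, out.device)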
