 from torch._inductor.template_heuristics import FlexConfig, FlexDecodeConfig
 
 import triton_kernels_benchmark as benchmark_suit
+import triton
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 # Use TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results.
 # torch._inductor.config.max_autotune_gemm = True
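# Note: a minimal standalone sketch of the device lookup added above, assuming a Triton
# build that exposes triton.runtime.driver.active.get_active_torch_device(); the tensor
# below is illustrative and is not part of the benchmark file.
import torch
import triton

DEVICE = triton.runtime.driver.active.get_active_torch_device()  # e.g. torch.device('xpu:0') on Intel GPUs
x = torch.empty(4, device=DEVICE)  # allocations follow whichever backend Triton is driving
print(x.device)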
@@ -59,7 +62,7 @@ def get_flex_decode_configs(*args, **kwargs): # pylint: disable=unused-argument
 
 
 @lru_cache
-def create_block_mask_cached(score_mod, B, H, M, N, device='xpu'):
+def create_block_mask_cached(score_mod, B, H, M, N, device=DEVICE):
     block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
     return block_mask
 
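# Note: an illustrative usage sketch of the cached mask builder above; the shapes are made
# up, and causal_mask is the mask function defined further down in this file. With the new
# default, callers may omit the device and still build the mask on the Triton-detected one.
mask_default = create_block_mask_cached(causal_mask, 1, 1, 1024, 1024)               # uses DEVICE
mask_on_cpu = create_block_mask_cached(causal_mask, 1, 1, 1024, 1024, device='cpu')  # explicit override
assert mask_default is create_block_mask_cached(causal_mask, 1, 1, 1024, 1024)       # lru_cache returns the same object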
@@ -136,19 +139,20 @@ def causal_mask(_, __, q_idx, kv_idx):
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
     assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
-    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
-    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
-    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=DEVICE)
     torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
 
     if provider == 'torch':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+                                                              device=DEVICE)
 
     elif provider == 'triton':
         kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
@@ -160,7 +164,8 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
 
         benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+                                                              device=DEVICE)
 
     elif provider == 'onednn':
         # OneDNN only supports MHA.
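# Note: a hedged usage sketch of the timing call updated above; benchmark_suit.do_bench and
# its device keyword come from this change, q/k/v/block_mask/sm_scale refer to the objects
# built in benchmark(), and quantiles [0.5, 0.0, 1.0] pick the median, min and max timings.
timed_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(timed_fn, n_warmup=10, n_repeat=10,
                                                      quantiles=[0.5, 0.0, 1.0], device=DEVICE)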