@@ -87,8 +87,8 @@ def causal_mask(_, __, q_idx, kv_idx):
         # Decode shapes of Deepseek-v3 (Rope)
         [],
         line_arg='provider',
-        line_vals=['triton'],
-        line_names=['Triton'],
+        line_vals=['triton', 'torch'],
+        line_names=['Triton', 'Torch'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
@@ -105,12 +105,16 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
     sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
-    if provider == 'triton':
+    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+    torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
+
+    if provider == 'torch':
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    elif provider == 'triton':
         kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
-        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
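For context, the new eager `torch` provider calls PyTorch's `flex_attention` API directly instead of the compiled Triton path. Below is a minimal, self-contained sketch of that path; it uses the stock `create_block_mask` rather than the benchmark's `create_block_mask_cached` wrapper, and the shapes and CPU device are illustrative assumptions standing in for the benchmark's Deepseek-v3 shape list and `xpu` device.

```python
# Sketch of the eager flex_attention path, assuming PyTorch >= 2.5.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    # Same mask_mod as the benchmark: a query attends only to
    # positions at or before its own index.
    return q_idx >= kv_idx


Z, H_q, H_kv, N_CTX, D_HEAD = 2, 8, 2, 1024, 64  # hypothetical GQA shapes
q = torch.randn(Z, H_q, N_CTX, D_HEAD)
k = torch.randn(Z, H_kv, N_CTX, D_HEAD)
v = torch.randn(Z, H_kv, N_CTX, D_HEAD)

# One block mask broadcast across batch and heads (B=1, H=1), as in the diff.
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device='cpu')

# enable_gqa must be True whenever H_q != H_kv (grouped-query attention);
# scale=1.3 mirrors the benchmark's sm_scale.
out = flex_attention(q, k, v, block_mask=block_mask, scale=1.3,
                     enable_gqa=H_q != H_kv)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

Sharing one `block_mask` between `torch_fn` and `triton_fn`, as the diff does by hoisting it above the provider branch, keeps the two providers timing the same masked workload.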