@@ -276,13 +276,12 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
             causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
             qk = tl.where(causal_mask, qk, float("-inf"))
         # -- compute qk ----
-
         if INT8_GEMM:
             qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE)
         else:
             if INT8_KV:
                 k = (k * k_descale).to(q.type.element_ty)
-            qk += tl.dot(q, k) * QK_SCALE
+            qk += (tl.dot(q, k) * QK_SCALE)
 
         if bias_ptrs is not None:
             bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
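For readers skimming the hunk above: the two INT8 branches differ only in where dequantization happens. A plain-PyTorch sketch of the same scaling order (illustrative only; per-tensor `q_descale`/`k_descale` scales are assumed, and `k` is taken pre-transposed to `(head_dim, BLOCK_N)` as in the kernel's `tl.dot(q, k)`):

```python
# Illustrative sketch of the INT8 scaling order in the hunk above; not the Triton kernel.
import torch

def qk_scores(q, k, q_descale, k_descale, qk_scale, int8_gemm=False, int8_kv=False):
    if int8_gemm:
        # INT8 GEMM: multiply the int8 operands, then dequantize the fp32 result.
        return (q.float() @ k.float()) * q_descale * k_descale * qk_scale
    if int8_kv:
        # INT8 KV: dequantize k back to q's dtype before the matmul.
        k = (k * k_descale).to(q.dtype)
    return (q @ k) * qk_scale
```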
@@ -1870,6 +1869,49 @@ def varlen_benchmark_configs():
     return configs
 
 
+def model_benchmark_configs(args):
+    import os
+    import json
+    # If user did not provide an absolute path, resolve relative path from script directory
+    if not os.path.isabs(args.model_configs):
+        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
+    else:
+        config_file = args.model_configs
+
+    with open(config_file, 'r') as f:
+        configs = json.load(f)
+    fa_configs = []
+
+    if args.model != "all":
+        # Check if the model exists
+        model_name = args.model
+        if model_name not in configs:
+            raise ValueError(f"Model '{model_name}' not found in {config_file}")
+        # Handle a specific model
+        config = configs[model_name]
+        HQ = config["num_attention_heads"]
+        HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
+
+        max_ctx_len = config["max_ctx_len"]
+        N_CTX_Q = args.sq if args.sq else max_ctx_len
+        N_CTX_K = args.sk if args.sk else max_ctx_len
+        batch_size = args.b if args.b else 1
+
+        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+    else:
+        # Handle all models
+        for model_name, config in configs.items():
+            HQ = config["num_attention_heads"]
+            HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
+            max_ctx_len = config["max_ctx_len"]
+            N_CTX_Q = args.sq if args.sq else max_ctx_len
+            N_CTX_K = args.sk if args.sk else max_ctx_len
+            batch_size = args.b if args.b else 1
+            fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+
+    return fa_configs
+
+
 def run_benchmark(custom, args):
 
     dtype = arg_to_torch_dtype[args.dtype]
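`model_benchmark_configs` expects each entry of `model_configs.json` to carry `num_attention_heads`, `num_key_value_heads` and `max_ctx_len`. A sketch of that schema and the tuple it turns into; the model name and numbers below are made up for illustration, not taken from the shipped file:

```python
# Hypothetical config entry mirroring the keys model_benchmark_configs() reads.
example_configs = {
    "example-7B": {
        "num_attention_heads": 32,   # -> HQ
        "num_key_value_heads": 8,    # -> HK (falls back to HQ when null/None)
        "max_ctx_len": 4096,         # -> default N_CTX_Q / N_CTX_K
    }
}
# With -b, -sq and -sk left at their 0 defaults, this entry becomes
# ("example-7B", 1, 32, 8, 4096, 4096)  ==  (model, BATCH, HQ, HK, N_CTX_Q, N_CTX_K).
```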
@@ -1884,6 +1926,7 @@ def run_benchmark(custom, args):
     int8_kv = args.int8_kv and int8
     varlen = args.layout == 'thd'
     configs = []
+    plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
     if custom:
         x_vals_list = [(args.b, args.hq, hk, args.sq, sk)]
     else:
@@ -1892,16 +1935,22 @@ def run_benchmark(custom, args):
         else:
             x_vals_list = nonvarlen_benchmark_configs()
 
+    if args.model:
+        x_vals_list = model_benchmark_configs(args)
+        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
+        plot_name = f'fused-attention-{mode}-layout{args.layout}'
+
     print_time = args.return_time
-    line_names = 'Time (ms)' if print_time else 'TFLOPS'
+    line_vals = ['triton', 'torch']  # 'Time (ms)' if print_time else 'TFLOPS'
     configs.append(
-        triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'],
-                                 line_names=[line_names], styles=[('red', '-')], ylabel='ms',
-                                 plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}',
+        triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals,
+                                 line_names=line_vals, styles=[('red', '-'),
+                                                               ('green', '-')], ylabel='ms', plot_name=plot_name,
                                  args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}))
 
     @triton.testing.perf_report(configs)
-    def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"):
+    def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda",
+                              model=None):
         assert mode in ["fwd", "bwd"]
         assert not (int8_kv and quantize_p)
         warmup = 25
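As a reminder of how `triton.testing` wires the pieces touched above: each tuple in `x_vals` is unpacked into the `x_names` parameters and each entry of `line_vals` arrives as `provider`, which is why `bench_flash_attention` now runs once for `'triton'` and once for `'torch'` per config. A minimal self-contained sketch with placeholder values and a dummy measurement:

```python
# Minimal sketch of the x_vals / line_vals plumbing; values and the returned
# "measurement" are placeholders, not real benchmark numbers.
import triton

configs = [
    triton.testing.Benchmark(x_names=['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'],
                             x_vals=[('example-7B', 1, 32, 8, 4096, 4096)], line_arg='provider',
                             line_vals=['triton', 'torch'], line_names=['triton', 'torch'],
                             styles=[('red', '-'), ('green', '-')], ylabel='ms', plot_name='sketch', args={})
]

@triton.testing.perf_report(configs)
def sketch(model, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, provider):
    return 1.0 if provider == 'triton' else 2.0  # stand-in for a timed run

sketch.run(print_data=True)
```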
@@ -1942,6 +1991,17 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
             o, _ = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
+
+        if "torch" in provider:
+            if HQ != HK:
+                k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
+                           k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
+                v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
+                           v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
+
+            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
+                                                                          is_causal=causal, scale=None)
+
         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         total_flops = 2 * flops_per_matmul
         if causal:
@@ -1959,7 +2019,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
         else:
             return total_flops / ms * 1e-9
 
-    bench_flash_attention.run(save_path=".", print_data=True)
+    bench_flash_attention.run(save_path=".", print_data=True, show_plots=True)
 
 
 def supported_layouts():
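The `view`/`expand`/`reshape` chain added for the torch provider replicates each KV head `HQ // HK` times so that `scaled_dot_product_attention` sees matching Q and KV head counts (grouped-query attention). A standalone cross-check of that transformation, assuming bhsd-shaped tensors; `repeat_interleave` is shown only as an equivalence check, not what the benchmark calls:

```python
# Standalone check of the KV-head replication used for the torch provider.
# Assumes bhsd layout: (batch, heads, seqlen, head_dim).
import torch

B, HQ, HK, N, D = 2, 8, 2, 128, 64
k = torch.randn(B, HK, N, D)

# (B, HK, N, D) -> (B, HK, 1, N, D) -> (B, HK, HQ // HK, N, D) -> (B, HQ, N, D)
k_expanded = k.view(B, HK, 1, N, D).expand(-1, -1, HQ // HK, -1, -1).reshape(B, HQ, N, D)

# Same head ordering as repeat_interleave along the head dimension.
assert torch.equal(k_expanded, k.repeat_interleave(HQ // HK, dim=1))
```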
@@ -1976,6 +2036,21 @@ def parse_args():
         prog="Benchmark FlashAttention",
         allow_abbrev=False,
     )
+    parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
+
+    def get_available_models(config_file='model_configs.json'):
+        """Load model names from the configuration file."""
+        import os
+        import json
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
+        with open(config_path, 'r') as f:
+            configs = json.load(f)
+        return list(configs.keys())
+
+    available_models = get_available_models()  # Dynamically load model names
+    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
+                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    parser.add_argument('-model', type=str, default=None, help=model_help)
     parser.add_argument("-b", type=int, default=0)
     parser.add_argument("-hq", type=int, default=0)
     parser.add_argument("-hk", type=int, default=0)
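Usage sketch for the flags added above, kept separate from the script itself; `example-7B` is a placeholder model name (real choices come from `model_configs.json` via `get_available_models`):

```python
# Usage sketch only; flag spellings and defaults mirror the arguments added above.
import argparse

parser = argparse.ArgumentParser(prog="Benchmark FlashAttention", allow_abbrev=False)
parser.add_argument('-model_configs', type=str, default="model_configs.json")
parser.add_argument('-model', type=str, default=None)
parser.add_argument('-b', type=int, default=0)

# Roughly equivalent to:  python flash-attention.py -model example-7B
# or, for every entry:    python flash-attention.py -model all
args = parser.parse_args(['-model', 'example-7B'])
print(args.model, args.model_configs, args.b)  # example-7B model_configs.json 0
```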
@@ -2006,13 +2081,17 @@ def main():
     custom_config = False
     assert args.layout == 'thd' or not args.equal_seqlens, \
            "Equal sequence lengths arg must be used with the thd layout."
-    if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
+    if args.hq or args.hk or args.d:
         custom_config = True
         assert args.b and args.hq and args.sq and args.d, \
             "If custom config is specified, please provide \
             all of batch, number of Q heads, Q sequence length \
             and head size."
 
+    if args.model:
+        assert not (args.hq or args.hk or args.d), \
+            "Specifying model fixes hq, hk and d already. Do not provide them!"
+
     assert args.dtype in arg_to_torch_dtype, \
            "Only fp16, bf16 and f32 types currently supported."
 