@@ -1879,10 +1879,10 @@ def model_benchmark_configs(args):
     for model_name, config in configs.items():
         HQ = config["num_attention_heads"]
         HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-        max_ctx_len = config["max_ctx_len"]
-        N_CTX_Q = args.sq if args.sq else max_ctx_len
-        N_CTX_K = args.sk if args.sk else max_ctx_len
-        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+        N_CTX_Q = args.sq if args.sq else 4096
+        N_CTX_K = args.sk if args.sk else N_CTX_Q
+        HEAD_DIM = config["hidden_size"] // HQ
+        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM))

     return fa_configs

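Side note on the new HEAD_DIM field: it uses the standard transformer convention that hidden_size splits evenly across the query heads. A minimal sketch of the derivation, with a hypothetical llama3-style config entry (illustrative values, not taken from model_configs.json):

    # Hypothetical config entry; real values come from model_configs.json.
    config = {"num_attention_heads": 32, "num_key_value_heads": 8, "hidden_size": 4096}

    HQ = config["num_attention_heads"]  # 32 query heads
    HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]  # 8 KV heads (GQA)
    HEAD_DIM = config["hidden_size"] // HQ  # 4096 // 32 = 128
    assert HEAD_DIM * HQ == config["hidden_size"]  # exact split for this config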
@@ -1902,6 +1902,7 @@ def run_benchmark(custom, args):
     varlen = args.layout == 'thd'
     configs = []
     plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
+    extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}
     if custom:
         x_vals_list = [(args.b, args.hq, hk, args.sq, sk)]
     else:
@@ -1912,16 +1913,16 @@ def run_benchmark(custom, args):

     if args.model:
         x_vals_list = model_benchmark_configs(args)
-        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
+        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K', 'D_HEAD']
         plot_name = f'fused-attention-{mode}-layout{args.layout}'
+        extra_args = {'dtype': dtype, 'causal': causal, 'mode': mode}

     print_time = args.return_time
     line_vals = ['triton', 'torch']  # 'Time (ms)' if print_time else 'TFLOPS'
     configs.append(
         triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals,
-                                 line_names=line_vals, styles=[('red', '-'),
-                                                               ('green', '-')], ylabel='ms', plot_name=plot_name,
-                                 args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}))
+                                 line_names=line_vals, styles=[('green', '-'), ('red', '-')],
+                                 ylabel='Time (ms)' if print_time else 'TFLOPS', plot_name=plot_name, args=extra_args))

     @triton.testing.perf_report(configs)
     def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda",
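For context on the extra_args split: triton.testing.Benchmark forwards each entry of its args dict as a fixed keyword argument to the decorated function, while x_names become per-point arguments, so D_HEAD may appear in exactly one of the two (hence the model path, which reads D_HEAD from the config, drops it from extra_args). A minimal, self-contained sketch of that pattern, with illustrative names and values (assumes a CUDA-enabled environment, as the benchmark itself does):

    import triton

    @triton.testing.perf_report([
        triton.testing.Benchmark(
            x_names=['N'], x_vals=[1024, 2048],  # varied per data point
            line_arg='provider', line_vals=['a', 'b'], line_names=['a', 'b'],
            styles=[('green', '-'), ('red', '-')], ylabel='ms', plot_name='demo',
            args={'scale': 2.0},  # fixed kwargs, playing the role of extra_args above
        )
    ])
    def demo(N, provider, scale):
        # Each point receives the x_names values, line_arg, and all args entries as kwargs.
        return triton.testing.do_bench(lambda: sum(range(N)) * scale)

    demo.run(print_data=True)  # prints one timing table per Benchmark config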
@@ -1956,26 +1957,35 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
         flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
         if causal:
             input_metadata.need_causal()
-        if int8:
-            q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv)

-        input_metadata.set_persistent(args.persistent)
-        o = torch.empty_like(q)
-        fn = lambda: attention(q, k, v, o, input_metadata)
-        if mode == 'bwd':
-            o, _ = fn()
-            do = torch.randn_like(o)
-            fn = lambda: o.backward(do, retain_graph=True)
-
-        if "torch" in provider:
-            if HQ != HK:
-                k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
-                           k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
-                v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
-                           v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
-
-            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
-                                                                          is_causal=causal, scale=None)
+        if "triton" in provider:
+            o = torch.empty_like(q)
+            if int8:
+                q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv)
+            input_metadata.set_persistent(args.persistent)
+            fn = lambda: attention(q, k, v, o, input_metadata)
+            if mode == 'bwd':
+                o, _ = fn()
+                do = torch.randn_like(o)
+                fn = lambda: o.backward(do, retain_graph=True)
+
+        elif "torch" in provider and args.layout in ["thd", "bhsd", "bshd"]:
+            # torch requires the layout to be (b (optional), ..., h, s, d)
+            if args.layout in ["thd", "bshd"]:
+                q = q.transpose(-3, -2)
+                k = k.transpose(-3, -2)
+                v = v.transpose(-3, -2)
+            # check if GQA
+            HQ = q.shape[-3]
+            HK = k.shape[-3]
+            if HQ != HK:  # TODO: sdpa(..., enable_gqa=True) should work
+                k = k.repeat_interleave(q.size(-3) // k.size(-3), -3)
+                v = v.repeat_interleave(q.size(-3) // v.size(-3), -3)
+
+            fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal, scale=input_metadata.sm_scale)
+        else:
+            assert False, f"Unknown provider {provider} in flash-attention."

         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         total_flops = 2 * flops_per_matmul
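A note on the rewritten torch path: repeat_interleave materializes each KV head HQ // HK times along the head axis, so SDPA sees matching Q and KV head counts; it replaces the old view/expand/reshape chain with the same result. A standalone sketch of that expansion in bhsd layout (shapes are illustrative):

    import torch

    B, HQ, HK, S, D = 2, 32, 8, 128, 64
    q = torch.randn(B, HQ, S, D)
    k = torch.randn(B, HK, S, D)
    v = torch.randn(B, HK, S, D)

    # Repeat each of the HK KV heads HQ // HK times along the head axis (-3).
    k = k.repeat_interleave(HQ // HK, dim=-3)  # -> (B, 32, S, D)
    v = v.repeat_interleave(HQ // HK, dim=-3)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert out.shape == (B, HQ, S, D)

As the TODO in the diff notes, newer PyTorch releases can skip the copies entirely via scaled_dot_product_attention(..., enable_gqa=True).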
@@ -1984,9 +1994,9 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
         seqlen_q = N_CTX_Q
         seqlen_k = N_CTX_K
         if seqlen_q > seqlen_k:
-            total_flops *= seqlen_k / (2 * seqlen_q)
+            total_flops *= (seqlen_k / (2 * seqlen_q))
         else:
-            total_flops *= 1 - seqlen_q / (2 * seqlen_k)
+            total_flops *= (1 - seqlen_q / (2 * seqlen_k))
         if mode == "bwd":
             total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         if print_time:
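The newly parenthesized factors implement the usual causal-mask FLOPs correction: the mask discards half of a square score matrix, and proportionally less when the surviving region dominates a rectangular one. A quick sanity check mirroring the two branches above (values are illustrative):

    # Fraction of the N_CTX_Q x N_CTX_K score matrix kept by a causal mask.
    def causal_factor(seqlen_q, seqlen_k):
        if seqlen_q > seqlen_k:
            return seqlen_k / (2 * seqlen_q)
        return 1 - seqlen_q / (2 * seqlen_k)

    assert causal_factor(4096, 4096) == 0.5    # square: half the scores survive
    assert causal_factor(1024, 4096) == 0.875  # short queries, long keys: most survive
    assert causal_factor(4096, 1024) == 0.125  # long queries, short keys: few survive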
@@ -2014,8 +2024,9 @@ def parse_args():
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")

     available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
-    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
-                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    model_help = (
+        "Model name to benchmark. Select from: [" + ", ".join(available_models) +
+        "]. Use 'all' to benchmark all models, or omit to run the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
     parser.add_argument("-b", type=int, default=0)
     parser.add_argument("-hq", type=int, default=0)