@@ -1330,8 +1330,12 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen
     if not equal_seqlens:
         max_seqlens_q = N_CTX_Q // Z
         max_seqlens_k = N_CTX_K // Z
-        seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
-        seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
+        if N_CTX_Q == N_CTX_K:
+            seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
+            seqlens_k = seqlens_q
+        else:
+            seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
+            seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
     else:
         seqlens_q = torch.full((Z, ), N_CTX_Q // Z)
         seqlens_k = torch.full((Z, ), N_CTX_K // Z)
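Note: with this change, when N_CTX_Q == N_CTX_K the random per-context lengths are drawn once and shared, so every context in the batch gets seqlen_q == seqlen_k (a "square" attention problem) instead of two independent draws. A minimal sketch of the idea; the zero-prefixed cumulative-sum bookkeeping for cu_seqlens is only an assumption about what the helper does downstream:

```python
import torch

Z, max_seqlen = 4, 1024  # hypothetical batch size and per-context length cap

seqlens_q = torch.randint(1, max_seqlen + 1, (Z,), dtype=torch.int32)
seqlens_k = seqlens_q  # tied lengths: every context is square, so Q and K see the same causal extent

# Typical varlen bookkeeping (assumed, not necessarily the helper's exact code):
cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32),
                          torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)])
cu_seqlens_k = cu_seqlens_q.clone()
```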
@@ -1900,7 +1904,7 @@ def model_benchmark_configs(args):
     for model_name, config in configs.items():
         HQ = config["num_attention_heads"]
         HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-        N_CTX_Q = args.sq if args.sq else 4096
+        N_CTX_Q = args.sq if args.sq else 8192
         N_CTX_K = args.sk if args.sk else N_CTX_Q
         HEAD_DIM = config["hidden_size"] // HQ
         fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM))
@@ -1916,11 +1920,11 @@ def run_benchmark(custom, args):
     head_size = 128 if not args.d else args.d
     mode = 'fwd'
     x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
-    causal = args.causal
+    causal = args.causal if not args.model else True
     int8 = args.int8
     quantize_p = args.quantize_p and int8
     int8_kv = args.int8_kv and int8
-    varlen = args.layout == 'thd'
+    varlen = True if args.model else args.layout == 'thd'
     configs = []
     plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
     extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}
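These two ternaries make --model imply the causal, varlen path regardless of the --causal and --layout flags. An equivalent, more explicit spelling of the new defaults (sketch only, not the code in this change):

```python
if args.model:
    causal = True   # model configs are always benchmarked with causal masking
    varlen = True   # and with variable-length (thd) inputs
else:
    causal = args.causal
    varlen = args.layout == 'thd'
```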
@@ -1969,13 +1973,23 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
         q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype,
                                                       args.equal_seqlens)
         for i in range(0, input_metadata.num_contexts):
-            seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]
-            seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]
-            # x2 for 2 GEMMs
-            flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2
+            seqlen_q = (input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]).item()
+            seqlen_k = (input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]).item()
+            # x2 in both cases for 2 GEMMs
+            if causal:
+                # If seqlen_q != seqlen_k, the causal mask skips part of the computation:
+                # either the lower triangle or the right triangle, depending on which seqlen is larger.
+                causal_correction = seqlen_k if seqlen_q > seqlen_k else seqlen_q
+                flops_per_matmul += (seqlen_q * seqlen_k - (causal_correction**2) / 2) * HQ * D_HEAD * 2
+            else:
+                flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2
     else:
         q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout)
-        flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
+        if causal:
+            causal_correction = N_CTX_K if N_CTX_Q > N_CTX_K else N_CTX_Q
+            flops_per_matmul = 2.0 * BATCH * HQ * (N_CTX_Q * N_CTX_K - (causal_correction**2) / 2) * D_HEAD
+        else:
+            flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
     if causal:
         input_metadata.need_causal()
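The correction term subtracts the (roughly triangular) block of query/key pairs that the causal mask skips; the triangle's side is min(seqlen_q, seqlen_k). A quick numeric check of the equal-length case with hypothetical sizes, where the formula reduces to the familiar "causal halves the FLOPs" rule of thumb:

```python
seqlen_q = seqlen_k = 4096   # hypothetical sequence lengths
HQ, D_HEAD = 32, 128         # hypothetical head count and head dimension

dense = seqlen_q * seqlen_k * HQ * D_HEAD * 2
causal_correction = min(seqlen_q, seqlen_k)
causal = (seqlen_q * seqlen_k - causal_correction ** 2 / 2) * HQ * D_HEAD * 2

assert causal == dense / 2   # causal masking does half the work when seqlen_q == seqlen_k
```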
@@ -2010,14 +2024,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
     total_flops = 2 * flops_per_matmul
-    if causal:
-        # total_flops *= 0.5 # normally, but we have to take into account the unequal seqlen_q/k
-        seqlen_q = N_CTX_Q
-        seqlen_k = N_CTX_K
-        if seqlen_q > seqlen_k:
-            total_flops *= (seqlen_k / (2 * seqlen_q))
-        else:
-            total_flops *= (1 - seqlen_q / (2 * seqlen_k))
     if mode == "bwd":
         total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
     if print_time:
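The post-hoc rescaling is dropped because the causal correction is now applied where flops_per_matmul is accumulated, per sequence; keeping both would double-count the reduction. For the seqlen_q <= seqlen_k branch the removed factor and the new per-sequence formula count the same number of attended pairs, e.g. with hypothetical lengths:

```python
seqlen_q, seqlen_k = 2048, 4096  # hypothetical lengths, seqlen_q <= seqlen_k

old_pairs = seqlen_q * seqlen_k * (1 - seqlen_q / (2 * seqlen_k))   # removed global factor
new_pairs = seqlen_q * seqlen_k - min(seqlen_q, seqlen_k) ** 2 / 2  # new per-sequence formula

assert old_pairs == new_pairs
```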
@@ -2077,8 +2083,8 @@ def parse_args():
 def main():
     args = parse_args()
     custom_config = False
-    assert args.layout == 'thd' or not args.equal_seqlens, \
-        "Equal sequence lengths arg must be used with the thd layout."
+    assert args.layout == 'thd' or not args.equal_seqlens or args.model, \
+        "Equal sequence lengths arg must be used with the thd layout or a model config."
     if args.hq or args.hk or args.d:
         custom_config = True
         assert args.b and args.hq and args.sq and args.d, \
@@ -2093,6 +2099,9 @@ def main():
     assert args.dtype in arg_to_torch_dtype, \
         "Only fp16, bf16 and f32 types currently supported."
 
+    if args.model:
+        print("Note: Model config sets causal masking and THD layout (varlen) by default.")
+
     run_benchmark(custom_config, args)