
Commit 5653590

Changes to FA benchmarking (#742)
1) Make causal the default when using one of the canned models
2) Fix the computation of TFLOPs
1 parent 0d975d4 commit 5653590

File tree

1 file changed: +29 −20 lines

python/perf-kernels/flash-attention.py

Lines changed: 29 additions & 20 deletions
@@ -1330,8 +1330,12 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen
     if not equal_seqlens:
         max_seqlens_q = N_CTX_Q // Z
         max_seqlens_k = N_CTX_K // Z
-        seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
-        seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
+        if N_CTX_Q == N_CTX_K:
+            seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
+            seqlens_k = seqlens_q
+        else:
+            seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
+            seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
     else:
         seqlens_q = torch.full((Z, ), N_CTX_Q // Z)
         seqlens_k = torch.full((Z, ), N_CTX_K // Z)
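
For illustration only (not part of the commit), a minimal standalone sketch of the sampling behaviour this hunk introduces, assuming only torch and hypothetical names: when N_CTX_Q == N_CTX_K, the per-context key lengths now reuse the query draw, so every context gets seqlen_q == seqlen_k.

import torch

def sample_seqlens(Z, N_CTX_Q, N_CTX_K):
    # Hypothetical helper mirroring the changed branch of varlen_input_helper.
    max_seqlens_q = N_CTX_Q // Z
    max_seqlens_k = N_CTX_K // Z
    if N_CTX_Q == N_CTX_K:
        # One draw shared by q and k: each context is square.
        seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
        seqlens_k = seqlens_q
    else:
        # Independent draws, as before.
        seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
        seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
    return seqlens_q, seqlens_k

q_lens, k_lens = sample_seqlens(Z=4, N_CTX_Q=8192, N_CTX_K=8192)
assert torch.equal(q_lens, k_lens)  # square contexts when the totals match

With matching per-context lengths, the causal FLOP correction added further down reduces to exactly half of the dense count for each context.
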
@@ -1900,7 +1904,7 @@ def model_benchmark_configs(args):
     for model_name, config in configs.items():
         HQ = config["num_attention_heads"]
         HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-        N_CTX_Q = args.sq if args.sq else 4096
+        N_CTX_Q = args.sq if args.sq else 8192
         N_CTX_K = args.sk if args.sk else N_CTX_Q
         HEAD_DIM = config["hidden_size"] // HQ
         fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM))
@@ -1916,11 +1920,11 @@ def run_benchmark(custom, args):
     head_size = 128 if not args.d else args.d
     mode = 'fwd'
     x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
-    causal = args.causal
+    causal = args.causal if not args.model else True
     int8 = args.int8
     quantize_p = args.quantize_p and int8
     int8_kv = args.int8_kv and int8
-    varlen = args.layout == 'thd'
+    varlen = True if args.model else args.layout == 'thd'
     configs = []
     plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
     extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}
@@ -1969,13 +1973,23 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
         q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype,
                                                       args.equal_seqlens)
         for i in range(0, input_metadata.num_contexts):
-            seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]
-            seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]
-            # x2 for 2 GEMMs
-            flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2
+            seqlen_q = (input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]).item()
+            seqlen_k = (input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]).item()
+            # x2 in both cases for 2 GEMMs
+            if causal:
+                # If seqlen_q != seqlen_k then the causal mask ignores computation
+                # depending on which seqlen is larger. Either the lower triangle, or right triangle
+                causal_correction = seqlen_k if seqlen_q > seqlen_k else seqlen_q
+                flops_per_matmul += (seqlen_q * seqlen_k - (causal_correction**2) / 2) * HQ * D_HEAD * 2
+            else:
+                flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2
     else:
         q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout)
-        flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
+        if causal:
+            causal_correction = N_CTX_K if N_CTX_Q > N_CTX_K else N_CTX_Q
+            flops_per_matmul = 2.0 * BATCH * HQ * (N_CTX_Q * N_CTX_K - (causal_correction**2) / 2) * D_HEAD
+        else:
+            flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
     if causal:
         input_metadata.need_causal()
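
A quick, hedged sanity check of the corrected FLOP count (my own sketch, not in the commit): the causal branch subtracts half of a min(seqlen_q, seqlen_k)-sided triangle from the seqlen_q x seqlen_k score matrix, so when the two lengths are equal the result is exactly half of the dense count, matching the old `total_flops *= 0.5` rule of thumb removed further down.

def attn_gemm_flops(seqlen_q, seqlen_k, nheads, head_dim, causal):
    # FLOPs of the QK^T and PV GEMMs for one context, following the formula
    # introduced in this commit (x2 for multiply+add, x2 for the two GEMMs).
    if causal:
        causal_correction = min(seqlen_q, seqlen_k)
        useful = seqlen_q * seqlen_k - causal_correction**2 / 2
    else:
        useful = seqlen_q * seqlen_k
    return 2 * useful * nheads * head_dim * 2

dense_flops = attn_gemm_flops(4096, 4096, 32, 128, causal=False)
causal_flops = attn_gemm_flops(4096, 4096, 32, 128, causal=True)
assert causal_flops == dense_flops / 2  # equal lengths: causal is half of dense
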

@@ -2010,14 +2024,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
 
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
     total_flops = 2 * flops_per_matmul
-    if causal:
-        # total_flops *= 0.5 # normally, but we have to take into account the unequal seqlen_q/k
-        seqlen_q = N_CTX_Q
-        seqlen_k = N_CTX_K
-        if seqlen_q > seqlen_k:
-            total_flops *= (seqlen_k / (2 * seqlen_q))
-        else:
-            total_flops *= (1 - seqlen_q / (2 * seqlen_k))
     if mode == "bwd":
         total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
     if print_time:
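
For context on this deletion, a small check (mine, not from the commit): the removed global factor gives the same result as the new per-term correction when seqlen_q <= seqlen_k, but not when seqlen_q > seqlen_k, and in the varlen path the correction is now applied per context rather than once with the global N_CTX_Q/N_CTX_K.

sq, sk = 2048, 8192                  # seqlen_q <= seqlen_k: old scaling equals the new correction
old = sq * sk * (1 - sq / (2 * sk))
new = sq * sk - sq**2 / 2
assert old == new

sq, sk = 8192, 2048                  # seqlen_q > seqlen_k: the two disagree
old = sq * sk * (sk / (2 * sq))      # removed scaling keeps only a seqlen_k-sided triangle
new = sq * sk - sk**2 / 2            # new formula keeps the rectangle minus that triangle
assert old != new
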
@@ -2077,8 +2083,8 @@ def parse_args():
 def main():
     args = parse_args()
     custom_config = False
-    assert args.layout == 'thd' or not args.equal_seqlens, \
-        "Equal sequence lengths arg must be used with the thd layout."
+    assert args.layout == 'thd' or not args.equal_seqlens or args.model, \
+        "Equal sequence lengths arg must be used with the thd layout or a model config."
     if args.hq or args.hk or args.d:
         custom_config = True
         assert args.b and args.hq and args.sq and args.d, \
@@ -2093,6 +2099,9 @@ def main():
     assert args.dtype in arg_to_torch_dtype, \
         "Only fp16, bf16 and f32 types currently supported."
 
+    if args.model:
+        print("Note: Model config sets causal masking and THD layout (varlen) by default.")
+
     run_benchmark(custom_config, args)