
Commit c824160

refactor model benchmarks (#704)
Added a couple of fixes that were noticed while benchmarking the FA Triton kernel against torch:

- Initialize the output tensor before int8 quantization; otherwise the output inherits the int8 dtype.
- Renamed 'sl' to 'sq' (marking sequence length) in rmsnorm and softmax as well, for consistency across kernels.
- Removed max_ctx_len, as it is not a well-defined model parameter; N_CTX_Q is instead taken from args.sq, which defaults to 4096.
- For clarity, D_HEAD is also printed in the output, since different models can have different values for it.
- Converted the thd and bshd layouts to torch-compatible layouts for the torch provider.

Co-authored-by: Tianxing Wu <[email protected]>
1 parent bc10f6b commit c824160
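The output-initialization fix relies on the fact that torch.empty_like inherits the dtype of its argument: if int8 quantization runs before the output is allocated, the benchmark's output buffer ends up int8. A minimal standalone sketch of the failure mode (the int8 conversion below is illustrative only, not the repo's quantize_input):

    import torch

    q = torch.randn(2, 8, 128, 64, dtype=torch.float16)

    # Allocate the output first: empty_like copies q's dtype, so o stays fp16.
    o = torch.empty_like(q)            # torch.float16 -- correct output dtype

    # If q is quantized to int8 first, an output allocated afterwards is int8 too.
    q_int8 = q.to(torch.int8)          # stand-in for quantize_input (illustrative)
    o_bad = torch.empty_like(q_int8)   # torch.int8 -- the bug this commit fixes

    print(o.dtype, o_bad.dtype)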

File tree

5 files changed: +61 / -58 lines changed


python/perf-kernels/flash-attention.py

Lines changed: 42 additions & 31 deletions
@@ -1879,10 +1879,10 @@ def model_benchmark_configs(args):
     for model_name, config in configs.items():
         HQ = config["num_attention_heads"]
         HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-        max_ctx_len = config["max_ctx_len"]
-        N_CTX_Q = args.sq if args.sq else max_ctx_len
-        N_CTX_K = args.sk if args.sk else max_ctx_len
-        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+        N_CTX_Q = args.sq if args.sq else 4096
+        N_CTX_K = args.sk if args.sk else N_CTX_Q
+        HEAD_DIM = config["hidden_size"] // HQ
+        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM))
 
     return fa_configs

@@ -1902,6 +1902,7 @@ def run_benchmark(custom, args):
     varlen = args.layout == 'thd'
     configs = []
     plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
+    extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}
     if custom:
         x_vals_list = [(args.b, args.hq, hk, args.sq, sk)]
     else:
@@ -1912,16 +1913,16 @@ def run_benchmark(custom, args):
 
     if args.model:
         x_vals_list = model_benchmark_configs(args)
-        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
+        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K', 'D_HEAD']
         plot_name = f'fused-attention-{mode}-layout{args.layout}'
+        extra_args = {'dtype': dtype, 'causal': causal, 'mode': mode}
 
     print_time = args.return_time
     line_vals = ['triton', 'torch']  # 'Time (ms)' if print_time else 'TFLOPS'
     configs.append(
         triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals,
-                                 line_names=line_vals, styles=[('red', '-'),
-                                                               ('green', '-')], ylabel='ms', plot_name=plot_name,
-                                 args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}))
+                                 line_names=line_vals, styles=[('green', '-'), ('red', '-')],
+                                 ylabel='Time (ms)' if print_time else 'TFLOPS', plot_name=plot_name, args=extra_args))
 
 
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda",
@@ -1956,26 +1957,35 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
     flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
     if causal:
         input_metadata.need_causal()
-    if int8:
-        q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv)
 
-    input_metadata.set_persistent(args.persistent)
-    o = torch.empty_like(q)
-    fn = lambda: attention(q, k, v, o, input_metadata)
-    if mode == 'bwd':
-        o, _ = fn()
-        do = torch.randn_like(o)
-        fn = lambda: o.backward(do, retain_graph=True)
-
-    if "torch" in provider:
-        if HQ != HK:
-            k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
-                       k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
-            v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
-                       v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
-
-        fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
-                                                                      is_causal=causal, scale=None)
+    if "triton" in provider:
+        o = torch.empty_like(q)
+        if int8:
+            q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv)
+        input_metadata.set_persistent(args.persistent)
+        fn = lambda: attention(q, k, v, o, input_metadata)
+        if mode == 'bwd':
+            o, _ = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+
+    elif "torch" in provider and args.layout in ["thd", "bhsd", "bshd"]:
+        # torch requires the layout to be (b (optional), ..., h, s, d)
+        if args.layout in ["thd", "bshd"]:
+            q = q.transpose(-3, -2)
+            k = k.transpose(-3, -2)
+            v = v.transpose(-3, -2)
+        # check if GQA
+        HQ = q.shape[-3]
+        HK = k.shape[-3]
+        if HQ != HK:  # TODO: sdpa(..., enable_gqa=True) should work
+            k = k.repeat_interleave(q.size(-3) // k.size(-3), -3)
+            v = v.repeat_interleave(q.size(-3) // v.size(-3), -3)
+
+        fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal, scale=input_metadata.sm_scale)
+    else:
+        assert False, f"Unknown provider {provider} in flash-attention."
 
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
     total_flops = 2 * flops_per_matmul
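The torch branch above first moves bshd/thd inputs into the (b, ..., h, s, d) layout that scaled_dot_product_attention expects, then expands the KV heads for GQA with repeat_interleave instead of the earlier view/expand/reshape chain. A small standalone sketch (shapes are made up) checking that the two expansions agree:

    import torch

    B, HQ, HK, S, D = 2, 8, 2, 16, 64
    k_bshd = torch.randn(B, S, HK, D)                 # bshd: sequence before heads

    # torch SDPA wants (b, h, s, d): swap the sequence and head dims.
    k = k_bshd.transpose(-3, -2).contiguous()          # (B, HK, S, D)

    # New approach: repeat each KV head HQ // HK times along the head dim.
    k_new = k.repeat_interleave(HQ // HK, -3)          # (B, HQ, S, D)

    # Old approach: insert a group dim, expand it, then flatten it back into the head dim.
    k_old = k.view(B, HK, 1, S, D).expand(-1, -1, HQ // HK, -1, -1).reshape(B, HQ, S, D)

    assert torch.equal(k_new, k_old)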
@@ -1984,9 +1994,9 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
     seqlen_q = N_CTX_Q
     seqlen_k = N_CTX_K
     if seqlen_q > seqlen_k:
-        total_flops *= seqlen_k / (2 * seqlen_q)
+        total_flops *= (seqlen_k / (2 * seqlen_q))
     else:
-        total_flops *= 1 - seqlen_q / (2 * seqlen_k)
+        total_flops *= (1 - seqlen_q / (2 * seqlen_k))
     if mode == "bwd":
         total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
     if print_time:
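A quick sanity check of the causal scaling above: the factor approximates the fraction of the score matrix that survives the causal mask (assuming the last query row can see the full KV range when seqlen_q < seqlen_k):

    # Fraction of attention-score entries kept by the causal mask, as used above.
    sq, sk = 4096, 4096
    print(1 - sq / (2 * sk))   # 0.5   -> a square causal mask keeps about half the matmul work
    sq, sk = 1024, 4096
    print(1 - sq / (2 * sk))   # 0.875 -> short queries against a long KV cache lose little work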
@@ -2014,8 +2024,9 @@ def parse_args():
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
 
     available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
-    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
-                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    model_help = (
+        "Model name to benchmark. Select from: [" + ", ".join(available_models) +
+        "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
     parser.add_argument("-b", type=int, default=0)
     parser.add_argument("-hq", type=int, default=0)

python/perf-kernels/gemm.py

Lines changed: 6 additions & 8 deletions
@@ -315,14 +315,12 @@ def parse_args():
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
 
     available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
-    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
-                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    model_help = (
+        "Model name to benchmark. Select from: [" + ", ".join(available_models) +
+        "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
-    parser.add_argument('-b', type=int, default=0,
-                        help="Batch size used together with model. Defaults to 1 if not provided.")
-    parser.add_argument(
-        '-sl', type=int, default=0,
-        help="Sequence length used together with model. Defaults to max_seq_len from model config if not provided.")
+    parser.add_argument('-b', type=int, default=0, help="Batch size used together with model.")
+    parser.add_argument('-sq', type=int, default=0, help="Sequence length used together with model.")
 
     parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
     parser.add_argument("-M", type=int, default=0)
@@ -348,7 +346,7 @@ def main():
     batch_size = args.b if args.b else 1
 
     for model_name, config in configs.items():
-        seq_len = args.sl if args.sl else config["max_ctx_len"]
+        seq_len = args.sq if args.sq else 4096
         M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
         mnk_list.append((model_name, M, N, K))
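As a worked example of the shapes this now produces, using the 70B values from model_configs.json below, the defaults b=1 and sq=4096, and assuming the usual C[M, N] = A[M, K] @ B[K, N] convention:

    batch_size, seq_len = 1, 4096
    M = batch_size * seq_len    # 4096
    N = 8192                    # hidden_size for the 70B config
    K = 28672                   # intermediate_size for the 70B config
    # -> roughly a (4096 x 28672) @ (28672 x 8192) GEMM per benchmarked model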

python/perf-kernels/model_configs.json

Lines changed: 1 addition & 3 deletions
@@ -4,23 +4,20 @@
       "num_attention_heads": 32,
       "num_key_value_heads": 8,
       "hidden_size": 4096,
-      "max_ctx_len": 8192,
       "intermediate_size": 14336,
       "vocab_size": 128256
     },
     "70B": {
       "num_attention_heads": 64,
       "num_key_value_heads": 8,
       "hidden_size": 8192,
-      "max_ctx_len": 8192,
       "intermediate_size": 28672,
       "vocab_size": 128256
     },
     "405B": {
       "num_attention_heads": 128,
       "num_key_value_heads": 8,
       "hidden_size": 16384,
-      "max_ctx_len": 8192,
       "intermediate_size": 53248,
       "vocab_size": 128256
     }
@@ -40,5 +37,6 @@
       "num_key_value_heads": 8,
       "vocab_size": 32000
     }
+
   }
 }
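These entries also feed the new HEAD_DIM = hidden_size // num_attention_heads computation in flash-attention.py; for the three llama3 configs shown above the derived head size works out the same:

    # hidden_size // num_attention_heads for the three llama3 entries above
    print(4096 // 32, 8192 // 64, 16384 // 128)   # 128 128 128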

python/perf-kernels/rmsnorm.py

Lines changed: 6 additions & 8 deletions
@@ -223,7 +223,7 @@ def model_benchmark_configs(args):
     batch_size = args.b if args.b else 1
 
     for model_name, config in configs.items():
-        seq_len = args.sl if args.sl else config["max_ctx_len"]
+        seq_len = args.sq if args.sq else 4096
         x_vals_list.append((model_name, batch_size * seq_len, config["hidden_size"]))
 
     return x_vals_list
@@ -309,14 +309,12 @@ def parse_args():
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
 
    available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
-    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
-                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    model_help = (
+        "Model name to benchmark. Select from: [" + ", ".join(available_models) +
+        "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
-    parser.add_argument('-b', type=int, default=0,
-                        help="Batch size used together with model. Defaults to 1 if not provided.")
-    parser.add_argument(
-        '-sl', type=int, default=0,
-        help="Sequence length used together with model. Defaults to max_seq_len from model config if not provided.")
+    parser.add_argument('-b', type=int, default=0, help="Batch size used together with model.")
+    parser.add_argument('-sq', type=int, default=0, help="Sequence length used together with model.")
     parser.add_argument('-M', "--M_start", default="1", type=int)
     parser.add_argument('-Ms', "--M_step", default="2", type=int)  # This is multiplicative step
     parser.add_argument('-Me', "--M_end", default="512", type=int)

python/perf-kernels/softmax.py

Lines changed: 6 additions & 8 deletions
@@ -142,7 +142,7 @@ def model_benchmark_configs(args):
     batch_size = args.b if args.b else 1
 
     for model_name, config in configs.items():
-        seq_len = args.sl if args.sl else config["max_ctx_len"]
+        seq_len = args.sq if args.sq else 4096
         x_vals_list.append((model_name, batch_size * seq_len, config["vocab_size"]))
 
     return x_vals_list
@@ -217,14 +217,12 @@ def parse_args():
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
 
     available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
-    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
-                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    model_help = (
+        "Model name to benchmark. Select from: [" + ", ".join(available_models) +
+        "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
-    parser.add_argument('-b', type=int, default=0,
-                        help="Batch size used together with model. Defaults to 1 if not provided.")
-    parser.add_argument(
-        '-sl', type=int, default=0,
-        help="Sequence length used together with model. Defaults to max_seq_len from model config if not provided.")
+    parser.add_argument('-b', type=int, default=0, help="Batch size used together with model.")
+    parser.add_argument('-sq', type=int, default=0, help="Sequence length used together with model.")
     parser.add_argument('-M', "--M_start", default="1", type=int)
     parser.add_argument('-Ms', "--M_step", default="2", type=int)
     parser.add_argument('-Me', "--M_end", default="512", type=int)
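After the rename, the model-config path of rmsnorm and softmax takes -sq instead of -sl for the sequence length. A hypothetical invocation (the model key is illustrative; the actual keys come from model_configs.json):

    python softmax.py -model llama3-8B -b 2 -sq 4096   # model name is hypothetical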
