
Commit 069281e

Adding model benchmarks (#691)
Adds model benchmarks to perf-kernels, with shapes derived from the real-life models (like llama3) configured in model_configs.json. rmsnorm, softmax, flash-attention and gemm can now run model benchmarks via the -model command-line argument.
1 parent 16ce746 commit 069281e
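For orientation, the sketch below mirrors how the new model_benchmark_configs() helpers (added in the diffs that follow) turn a model_configs.json entry into flash-attention benchmark shapes. It is only an illustration: the standalone script, variable names and the hard-coded 'llama3_8B' choice are assumptions, while the JSON keys and defaults come from the commit itself.

import json
import os

# Illustrative sketch only -- mirrors the model_benchmark_configs() helper added below.
# Assumes model_configs.json sits next to the benchmark script, as the helpers expect.
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_configs.json")
with open(config_file, 'r') as f:
    configs = json.load(f)

model = "llama3_8B"  # value that would be passed via -model
cfg = configs[model]
HQ = cfg["num_attention_heads"]
HK = HQ if cfg["num_key_value_heads"] is None else cfg["num_key_value_heads"]
n_ctx = cfg["max_ctx_len"]  # -sq / -sk override this when given
batch = 1                   # -b overrides this when given

# For llama3_8B this yields ('llama3_8B', 1, 32, 8, 8192, 8192).
print((model, batch, HQ, HK, n_ctx, n_ctx))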

File tree

5 files changed: +305 −15 lines changed

python/perf-kernels/flash-attention.py

Lines changed: 88 additions & 9 deletions
@@ -276,13 +276,12 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
        causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
        qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
-
        if INT8_GEMM:
            qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE)
        else:
            if INT8_KV:
                k = (k * k_descale).to(q.type.element_ty)
-            qk += tl.dot(q, k) * QK_SCALE
+            qk += (tl.dot(q, k) * QK_SCALE)

        if bias_ptrs is not None:
            bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
@@ -1870,6 +1869,49 @@ def varlen_benchmark_configs():
    return configs


+def model_benchmark_configs(args):
+    import os
+    import json
+    # If user did not provide an absolute path, resolve relative path from script directory
+    if not os.path.isabs(args.model_configs):
+        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
+    else:
+        config_file = args.model_configs
+
+    with open(config_file, 'r') as f:
+        configs = json.load(f)
+    fa_configs = []
+
+    if args.model != "all":
+        # Check if the model exists
+        model_name = args.model
+        if model_name not in configs:
+            raise ValueError(f"Model '{model_name}' not found in {config_file}")
+        # Handle a specific model
+        config = configs[model_name]
+        HQ = config["num_attention_heads"]
+        HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
+
+        max_ctx_len = config["max_ctx_len"]
+        N_CTX_Q = args.sq if args.sq else max_ctx_len
+        N_CTX_K = args.sk if args.sk else max_ctx_len
+        batch_size = args.b if args.b else 1
+
+        fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+    else:
+        # Handle all models
+        for model_name, config in configs.items():
+            HQ = config["num_attention_heads"]
+            HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
+            max_ctx_len = config["max_ctx_len"]
+            N_CTX_Q = args.sq if args.sq else max_ctx_len
+            N_CTX_K = args.sk if args.sk else max_ctx_len
+            batch_size = args.b if args.b else 1
+            fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
+
+    return fa_configs
+
+
def run_benchmark(custom, args):

    dtype = arg_to_torch_dtype[args.dtype]
@@ -1884,6 +1926,7 @@ def run_benchmark(custom, args):
    int8_kv = args.int8_kv and int8
    varlen = args.layout == 'thd'
    configs = []
+    plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}'
    if custom:
        x_vals_list = [(args.b, args.hq, hk, args.sq, sk)]
    else:
@@ -1892,16 +1935,22 @@ def run_benchmark(custom, args):
        else:
            x_vals_list = nonvarlen_benchmark_configs()

+    if args.model:
+        x_vals_list = model_benchmark_configs(args)
+        x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
+        plot_name = f'fused-attention-{mode}-layout{args.layout}'
+
    print_time = args.return_time
-    line_names = 'Time (ms)' if print_time else 'TFLOPS'
+    line_vals = ['triton', 'torch'] # 'Time (ms)' if print_time else 'TFLOPS'
    configs.append(
-        triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'],
-                                 line_names=[line_names], styles=[('red', '-')], ylabel='ms',
-                                 plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}',
+        triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals,
+                                 line_names=line_vals, styles=[('red', '-'),
+                                                               ('green', '-')], ylabel='ms', plot_name=plot_name,
                                 args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}))

    @triton.testing.perf_report(configs)
-    def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"):
+    def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda",
+                              model=None):
        assert mode in ["fwd", "bwd"]
        assert not (int8_kv and quantize_p)
        warmup = 25
@@ -1942,6 +1991,17 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
            o, _ = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
+
+        if "torch" in provider:
+            if HQ != HK:
+                k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
+                           k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
+                v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
+                           v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
+
+            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
+                                                                           is_causal=causal, scale=None)
+
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        total_flops = 2 * flops_per_matmul
        if causal:
@@ -1959,7 +2019,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
        else:
            return total_flops / ms * 1e-9

-    bench_flash_attention.run(save_path=".", print_data=True)
+    bench_flash_attention.run(save_path=".", print_data=True, show_plots=True)


def supported_layouts():
@@ -1976,6 +2036,21 @@ def parse_args():
        prog="Benchmark FlashAttention",
        allow_abbrev=False,
    )
+    parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
+
+    def get_available_models(config_file='model_configs.json'):
+        import os
+        import json
+        """Load model names from the configuration file."""
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
+        with open(config_path, 'r') as f:
+            configs = json.load(f)
+        return list(configs.keys())
+
+    available_models = get_available_models() # Dynamically load model names
+    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
+                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    parser.add_argument('-model', type=str, default=None, help=model_help)
    parser.add_argument("-b", type=int, default=0)
    parser.add_argument("-hq", type=int, default=0)
    parser.add_argument("-hk", type=int, default=0)
@@ -2006,13 +2081,17 @@ def main():
    custom_config = False
    assert args.layout == 'thd' or not args.equal_seqlens, \
        "Equal sequence lengths arg must be used with the thd layout."
-    if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
+    if args.hq or args.hk or args.d:
        custom_config = True
        assert args.b and args.hq and args.sq and args.d, \
            "If custom config is specified, please provide \
            all of batch, number of Q heads, Q sequence length \
            and head size."

+    if args.model:
+        assert not (args.hq or args.hk or args.d), \
+            "Specifying model fixes hq, hk and d already. Do not provide them!"
+
    assert args.dtype in arg_to_torch_dtype, \
        "Only fp16, bf16 and f32 types currently supported."
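One note on the torch provider added above: when HQ != HK (grouped-query attention), the benchmark repeats each KV head HQ // HK times before calling scaled_dot_product_attention. Below is a minimal sketch of that expand/reshape trick on a toy tensor; the sizes are purely illustrative and a heads-in-dim-1 layout is assumed.

import torch

B, HK, HQ, N_CTX, D = 2, 8, 32, 16, 128  # illustrative GQA sizes; HQ // HK == 4
k = torch.randn(B, HK, N_CTX, D)

# Insert a repeat axis, broadcast it HQ // HK times, then fold it into the head dim,
# mirroring the k.view(...).expand(...).reshape(...) chain in the diff above.
k_expanded = k.view(B, HK, 1, N_CTX, D).expand(-1, -1, HQ // HK, -1, -1).reshape(B, HQ, N_CTX, D)
print(k_expanded.shape)  # torch.Size([2, 32, 16, 128])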

python/perf-kernels/gemm.py

Lines changed: 69 additions & 2 deletions
@@ -6,6 +6,8 @@
import pytest
import re

+import os
+

@triton.autotune(
    configs=[
@@ -275,7 +277,7 @@ def get_type(provider):
        plot_name="matmul-performance",
        args={},
    ))
-def benchmark(M, N, K, provider):
+def benchmark(M, N, K, provider, model=None):
    in_dtype = name_to_torch_types[get_type(provider)]
    out_dtype = in_dtype
@@ -304,14 +306,37 @@ def benchmark(M, N, K, provider):
    return perf(ms), perf(max_ms), perf(min_ms)


-# TODO(vgokhale): Add more options to benchmarking
def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )

+    parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
+
+    def get_available_models(config_file='model_configs.json'):
+        import json
+        """Load model names from the configuration file."""
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
+        with open(config_path, 'r') as f:
+            configs = json.load(f)
+        return list(configs.keys())
+
+    available_models = get_available_models() # Dynamically load model names
+    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
+                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    parser.add_argument('-model', type=str, default=None, help=model_help)
+    parser.add_argument('-b', type=int, default=0,
+                        help="Batch size used together with model. Defaults to 1 if not provided.")
+    parser.add_argument(
+        '-sl', type=int, default=0,
+        help="Sequence length used together with model. Defaults to max_seq_len from model config if not provided.")
+
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
+    parser.add_argument("-M", type=int, default=0)
+    parser.add_argument("-N", type=int, default=0)
+    parser.add_argument("-K", type=int, default=0)
+
    args = parser.parse_args()

    return args
@@ -323,6 +348,48 @@ def main():
    global verbose
    args = parse_args()
    verbose = args.v
+
+    if args.model:
+        batch_size = args.b if args.b else 1
+        import os
+        import json
+        # If user did not provide an absolute path, resolve relative path from script directory
+        if not os.path.isabs(args.model_configs):
+            config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
+        else:
+            config_file = args.model_configs
+
+        with open(config_file, 'r') as f:
+            configs = json.load(f)
+        mnk_list = []
+
+        if args.model != "all":
+            model_name = args.model
+            # Check if the model exists
+            if model_name not in configs:
+                raise ValueError(f"Model '{model_name}' not found in {config_file}")
+            # Handle a specific model
+            config = configs[model_name]
+            seq_len = args.sl if args.sl else config["max_ctx_len"]
+            M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
+            mnk_list.append((model_name, M, N, K))
+        else:
+            # Handle all models
+            for model_name, config in configs.items():
+                seq_len = args.sl if args.sl else config["max_ctx_len"]
+                M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
+                mnk_list.append((model_name, M, N, K))
+
+        benchmark.benchmarks.x_names = ['model', 'M', 'N', 'K']
+        benchmark.benchmarks.x_vals = mnk_list
+
+    if args.M or args.N or args.K:
+        assert args.model is None, "Providing both -model and -M/N/K is not compatible! -model already fixes -M/N/K."
+
+    if args.M and args.N and args.K:
+        x_vals = [(args.M, args.N, args.K)]
+        benchmark.benchmarks.x_vals = x_vals
+
    benchmark.run(show_plots=True, print_data=True)

python/perf-kernels/model_configs.json

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
{
    "llama3_8B": {
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "hidden_size": 4096,
        "max_ctx_len": 8192,
        "intermediate_size": 14336,
        "vocab_size": 128256
    },
    "llama3_70B": {
        "num_attention_heads": 64,
        "num_key_value_heads": 8,
        "hidden_size": 8192,
        "max_ctx_len": 8192,
        "intermediate_size": 28672,
        "vocab_size": 128256
    },
    "llama3_405B": {
        "num_attention_heads": 128,
        "num_key_value_heads": 8,
        "hidden_size": 16384,
        "max_ctx_len": 8192,
        "intermediate_size": 53248,
        "vocab_size": 128256
    }
}
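To make the numbers concrete, here is a hedged sketch of the (model, M, N, K) tuples that gemm.py's main() builds from these entries with the default batch size of 1 and the default sequence length (max_ctx_len); the loop is illustrative and assumes it is run from python/perf-kernels.

import json

# Illustrative only -- reproduces the mnk_list construction from gemm.py's main().
with open("model_configs.json", 'r') as f:  # assumed to be run from python/perf-kernels
    configs = json.load(f)

batch_size = 1
for name, cfg in configs.items():
    seq_len = cfg["max_ctx_len"]
    print((name, batch_size * seq_len, cfg["hidden_size"], cfg["intermediate_size"]))
# ('llama3_8B', 8192, 4096, 14336)
# ('llama3_70B', 8192, 8192, 28672)
# ('llama3_405B', 8192, 16384, 53248)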

python/perf-kernels/rmsnorm.py

Lines changed: 60 additions & 2 deletions
@@ -170,6 +170,37 @@ def test_rmsnorm(M, N):
arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}


+def model_benchmark_configs(args):
+    import os
+    import json
+    # If user did not provide an absolute path, resolve relative path from script directory
+    if not os.path.isabs(args.model_configs):
+        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
+    else:
+        config_file = args.model_configs
+
+    with open(config_file, 'r') as f:
+        configs = json.load(f)
+
+    x_vals_list = []
+    batch_size = args.b if args.b else 1
+
+    if args.model == "all":
+        for model_name, config in configs.items():
+            seq_len = args.sl if args.sl else config["max_ctx_len"]
+            x_vals_list.append((model_name, batch_size * seq_len, config["hidden_size"]))
+    else:
+        if args.model not in configs:
+            raise ValueError(f"Model '{args.model}' not found in {config_file}")
+        # Handle a specific model
+        model_name = args.model
+        config = configs[model_name]
+        seq_len = args.sl if args.sl else config["max_ctx_len"]
+        x_vals_list.append((model_name, batch_size * seq_len, config["hidden_size"]))
+
+    return x_vals_list
+
+
def run_benchmark(args):
    config = []
    if (args.M_benchmark):
@@ -189,6 +220,14 @@ def run_benchmark(args):
    plot_name = str("rmsnorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) +
                    "-" + str(args.N_end) + "-" + str(args.N_step))

+    if args.model:
+        assert not args.M_benchmark, \
+            "Trying to provide both -model benchmark and M_benchmark is not supported!"
+        x_names = ['model', 'M', 'N']
+        mn_args = {}
+        plot_name = str("rmsnorm-performance_" + args.dtype)
+        x_vals_list = model_benchmark_configs(args)
+
    dtype = arg_to_torch_dtype[args.dtype]

    print(plot_name)
@@ -206,7 +245,7 @@ def run_benchmark(args):
    ))

    @triton.testing.perf_report(config)
-    def benchmark(M, N, provider):
+    def benchmark(M, N, provider, model=None):
        x = torch.randn(M, N, device='cuda', dtype=dtype)
        y = torch.zeros_like(x, device='cuda')
        n_rows, n_cols = x.shape
@@ -237,7 +276,26 @@ def parse_args():
        prog="Benchmark RMSNorm",
        allow_abbrev=False,
    )
-
+    parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")
+
+    def get_available_models(config_file='model_configs.json'):
+        import os
+        import json
+        """Load model names from the configuration file."""
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
+        with open(config_path, 'r') as f:
+            configs = json.load(f)
+        return list(configs.keys())
+
+    available_models = get_available_models() # Dynamically load model names
+    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
+                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
+    parser.add_argument('-model', type=str, default=None, help=model_help)
+    parser.add_argument('-b', type=int, default=0,
+                        help="Batch size used together with model. Defaults to 1 if not provided.")
+    parser.add_argument(
+        '-sl', type=int, default=0,
+        help="Sequence length used together with model. Defaults to max_seq_len from model config if not provided.")
    parser.add_argument('-M', "--M_start", default="1", type=int)
    parser.add_argument('-Ms', "--M_step", default="2", type=int) #This is multiplicative step
    parser.add_argument('-Me', "--M_end", default="512", type=int)
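Under the same assumptions, rmsnorm.py's model path benchmarks rows of shape (model, M, N), where M = batch_size * seq_len and N = hidden_size; a minimal, illustrative check for one entry:

# Illustrative only -- mirrors model_benchmark_configs() in rmsnorm.py for a single model.
cfg = {"hidden_size": 8192, "max_ctx_len": 8192}  # llama3_70B values from model_configs.json
batch_size, seq_len = 1, cfg["max_ctx_len"]
print(("llama3_70B", batch_size * seq_len, cfg["hidden_size"]))  # ('llama3_70B', 8192, 8192)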
