
Commit 902b832

benchmark utils added (#694)
Added benchmark utils to make the code more readable.
1 parent 069281e commit 902b832

File tree

6 files changed: +118, -157 lines

python/perf-kernels/flash-attention.py

Lines changed: 6 additions & 40 deletions
@@ -28,6 +28,7 @@

 import triton
 import triton.language as tl
+from utils.benchmark_utils import get_available_models, get_model_configs


 class MetaData():

@@ -1870,44 +1871,18 @@ def varlen_benchmark_configs():


 def model_benchmark_configs(args):
-    import os
-    import json
-    # If user did not provide an absolute path, resolve relative path from script directory
-    if not os.path.isabs(args.model_configs):
-        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
-    else:
-        config_file = args.model_configs
-
-    with open(config_file, 'r') as f:
-        configs = json.load(f)
+    config_file = args.model_configs
+    configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)
     fa_configs = []
+    batch_size = args.b if args.b else 1

-    if args.model != "all":
-        # Check if the model exists
-        model_name = args.model
-        if model_name not in configs:
-            raise ValueError(f"Model '{model_name}' not found in {config_file}")
-        # Handle a specific model
-        config = configs[model_name]
+    for model_name, config in configs.items():
         HQ = config["num_attention_heads"]
         HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-
         max_ctx_len = config["max_ctx_len"]
         N_CTX_Q = args.sq if args.sq else max_ctx_len
         N_CTX_K = args.sk if args.sk else max_ctx_len
-        batch_size = args.b if args.b else 1
-
         fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
-    else:
-        # Handle all models
-        for model_name, config in configs.items():
-            HQ = config["num_attention_heads"]
-            HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
-            max_ctx_len = config["max_ctx_len"]
-            N_CTX_Q = args.sq if args.sq else max_ctx_len
-            N_CTX_K = args.sk if args.sk else max_ctx_len
-            batch_size = args.b if args.b else 1
-            fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K))

     return fa_configs

@@ -2038,16 +2013,7 @@ def parse_args():
     )
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")

-    def get_available_models(config_file='model_configs.json'):
-        import os
-        import json
-        """Load model names from the configuration file."""
-        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
-        with open(config_path, 'r') as f:
-            configs = json.load(f)
-        return list(configs.keys())
-
-    available_models = get_available_models()  # Dynamically load model names
+    available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
     model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
                   "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
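
For context, a minimal sketch (not part of the commit) of the tuples the refactored model_benchmark_configs yields. The values come from the llama3 8B entry in model_configs.json, assuming the defaults batch_size=1 and no -sq/-sk overrides:

    # Standalone reproduction of the per-model tuple construction.
    config = {"num_attention_heads": 32, "num_key_value_heads": 8, "max_ctx_len": 8192}

    HQ = config["num_attention_heads"]  # 32 query heads
    # Fall back to HQ when num_key_value_heads is null (no GQA)
    HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"]
    N_CTX_Q = N_CTX_K = config["max_ctx_len"]  # 8192 when -sq/-sk are not given

    print(("llama3-8B", 1, HQ, HK, N_CTX_Q, N_CTX_K))
    # ('llama3-8B', 1, 32, 8, 8192, 8192)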

python/perf-kernels/gemm.py

Lines changed: 6 additions & 34 deletions
@@ -6,7 +6,7 @@
 import pytest
 import re

-import os
+from utils.benchmark_utils import get_available_models, get_model_configs


 @triton.autotune(

@@ -314,15 +314,7 @@ def parse_args():

     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")

-    def get_available_models(config_file='model_configs.json'):
-        import json
-        """Load model names from the configuration file."""
-        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
-        with open(config_path, 'r') as f:
-            configs = json.load(f)
-        return list(configs.keys())
-
-    available_models = get_available_models()  # Dynamically load model names
+    available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
     model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
                   "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
     parser.add_argument('-model', type=str, default=None, help=model_help)

@@ -350,35 +342,15 @@ def main():
     verbose = args.v

     if args.model:
-        batch_size = args.b if args.b else 1
-        import os
-        import json
-        # If user did not provide an absolute path, resolve relative path from script directory
-        if not os.path.isabs(args.model_configs):
-            config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
-        else:
-            config_file = args.model_configs
-
-        with open(config_file, 'r') as f:
-            configs = json.load(f)
+        config_file = args.model_configs
+        configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)
         mnk_list = []
+        batch_size = args.b if args.b else 1

-        if args.model != "all":
-            model_name = args.model
-            # Check if the model exists
-            if model_name not in configs:
-                raise ValueError(f"Model '{model_name}' not found in {config_file}")
-            # Handle a specific model
-            config = configs[model_name]
+        for model_name, config in configs.items():
             seq_len = args.sl if args.sl else config["max_ctx_len"]
             M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
             mnk_list.append((model_name, M, N, K))
-        else:
-            # Handle all models
-            for model_name, config in configs.items():
-                seq_len = args.sl if args.sl else config["max_ctx_len"]
-                M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
-                mnk_list.append((model_name, M, N, K))

         benchmark.benchmarks.x_names = ['model', 'M', 'N', 'K']
         benchmark.benchmarks.x_vals = mnk_list
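
As a worked example (not part of the commit), the GEMM shape derived for the llama3 8B config with batch_size=1 and no -sl override:

    # M = batch_size * seq_len, N = hidden_size, K = intermediate_size
    config = {"max_ctx_len": 8192, "hidden_size": 4096, "intermediate_size": 14336}
    batch_size = 1
    seq_len = config["max_ctx_len"]  # no -sl given
    M, N, K = batch_size * seq_len, config["hidden_size"], config["intermediate_size"]
    print(("llama3-8B", M, N, K))  # ('llama3-8B', 8192, 4096, 14336)
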
python/perf-kernels/model_configs.json

Lines changed: 25 additions & 23 deletions
@@ -1,26 +1,28 @@
 {
-    "llama3_8B": {
-        "num_attention_heads": 32,
-        "num_key_value_heads": 8,
-        "hidden_size": 4096,
-        "max_ctx_len": 8192,
-        "intermediate_size": 14336,
-        "vocab_size": 128256
-    },
-    "llama3_70B": {
-        "num_attention_heads": 64,
-        "num_key_value_heads": 8,
-        "hidden_size": 8192,
-        "max_ctx_len": 8192,
-        "intermediate_size": 28672,
-        "vocab_size": 128256
-    },
-    "llama3_405B": {
-        "num_attention_heads": 128,
-        "num_key_value_heads": 8,
-        "hidden_size": 16384,
-        "max_ctx_len": 8192,
-        "intermediate_size": 53248,
-        "vocab_size": 128256
+    "llama3": {
+        "8B": {
+            "num_attention_heads": 32,
+            "num_key_value_heads": 8,
+            "hidden_size": 4096,
+            "max_ctx_len": 8192,
+            "intermediate_size": 14336,
+            "vocab_size": 128256
+        },
+        "70B": {
+            "num_attention_heads": 64,
+            "num_key_value_heads": 8,
+            "hidden_size": 8192,
+            "max_ctx_len": 8192,
+            "intermediate_size": 28672,
+            "vocab_size": 128256
+        },
+        "405B": {
+            "num_attention_heads": 128,
+            "num_key_value_heads": 8,
+            "hidden_size": 16384,
+            "max_ctx_len": 8192,
+            "intermediate_size": 53248,
+            "vocab_size": 128256
+        }
     }
 }
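
The schema change nests model sizes under a family key. A short sketch (not part of the commit, assuming this file is on disk) of how lookups change:

    import json

    # Path resolved relative to perf-kernels/ by the new utils; plain filename assumed here.
    with open("model_configs.json") as f:
        configs = json.load(f)

    cfg = configs["llama3"]["8B"]  # was configs["llama3_8B"] before this commit
    print(cfg["hidden_size"])      # 4096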

python/perf-kernels/rmsnorm.py

Lines changed: 5 additions & 30 deletions
@@ -6,6 +6,7 @@

 import triton
 import triton.language as tl
+from utils.benchmark_utils import get_available_models, get_model_configs


 def is_cuda():

@@ -171,30 +172,13 @@ def test_rmsnorm(M, N):


 def model_benchmark_configs(args):
-    import os
-    import json
-    # If user did not provide an absolute path, resolve relative path from script directory
-    if not os.path.isabs(args.model_configs):
-        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
-    else:
-        config_file = args.model_configs
-
-    with open(config_file, 'r') as f:
-        configs = json.load(f)
+    config_file = args.model_configs
+    configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)

     x_vals_list = []
     batch_size = args.b if args.b else 1

-    if args.model == "all":
-        for model_name, config in configs.items():
-            seq_len = args.sl if args.sl else config["max_ctx_len"]
-            x_vals_list.append((model_name, batch_size * seq_len, config["hidden_size"]))
-    else:
-        if args.model not in configs:
-            raise ValueError(f"Model '{args.model}' not found in {config_file}")
-        # Handle a specific model
-        model_name = args.model
-        config = configs[model_name]
+    for model_name, config in configs.items():
         seq_len = args.sl if args.sl else config["max_ctx_len"]
         x_vals_list.append((model_name, batch_size * seq_len, config["hidden_size"]))

@@ -278,16 +262,7 @@ def parse_args():
     )
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")

-    def get_available_models(config_file='model_configs.json'):
-        import os
-        import json
-        """Load model names from the configuration file."""
-        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
-        with open(config_path, 'r') as f:
-            configs = json.load(f)
-        return list(configs.keys())
-
-    available_models = get_available_models()  # Dynamically load model names
+    available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
     model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
                   "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
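
A brief sketch (not part of the commit) of the x_vals tuple this now produces for llama3 8B, with batch_size=1 and no -sl override:

    # (model, M, N) with M = batch_size * seq_len and N = hidden_size
    config = {"max_ctx_len": 8192, "hidden_size": 4096}
    print(("llama3-8B", 1 * config["max_ctx_len"], config["hidden_size"]))
    # ('llama3-8B', 8192, 4096)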

python/perf-kernels/softmax.py

Lines changed: 5 additions & 30 deletions
@@ -5,6 +5,7 @@

 import triton
 import triton.language as tl
+from utils.benchmark_utils import get_available_models, get_model_configs


 def is_cuda():

@@ -134,30 +135,13 @@ def test_softmax(M, N):


 def model_benchmark_configs(args):
-    import os
-    import json
-    # If user did not provide an absolute path, resolve relative path from script directory
-    if not os.path.isabs(args.model_configs):
-        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.model_configs)
-    else:
-        config_file = args.model_configs
-
-    with open(config_file, 'r') as f:
-        configs = json.load(f)
+    config_file = args.model_configs
+    configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)

     x_vals_list = []
     batch_size = args.b if args.b else 1

-    if args.model == "all":
-        for model_name, config in configs.items():
-            seq_len = args.sl if args.sl else config["max_ctx_len"]
-            x_vals_list.append((model_name, batch_size * seq_len, config["vocab_size"]))
-    else:
-        if args.model not in configs:
-            raise ValueError(f"Model '{args.model}' not found in {config_file}")
-        # Handle a specific model
-        model_name = args.model
-        config = configs[model_name]
+    for model_name, config in configs.items():
         seq_len = args.sl if args.sl else config["max_ctx_len"]
         x_vals_list.append((model_name, batch_size * seq_len, config["vocab_size"]))

@@ -232,16 +216,7 @@ def parse_args():
     )
     parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.")

-    def get_available_models(config_file='model_configs.json'):
-        import os
-        import json
-        """Load model names from the configuration file."""
-        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
-        with open(config_path, 'r') as f:
-            configs = json.load(f)
-        return list(configs.keys())
-
-    available_models = get_available_models()  # Dynamically load model names
+    available_models = get_available_models(model_families=["llama3"])  # Dynamically load model names
     model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
                   "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
     parser.add_argument('-model', type=str, default=None, help=model_help)
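
The softmax benchmark follows the same pattern as rmsnorm, but its N dimension is vocab_size rather than hidden_size. A brief sketch (not part of the commit) for llama3 8B with batch_size=1:

    # (model, M, N) with M = batch_size * seq_len and N = vocab_size
    config = {"max_ctx_len": 8192, "vocab_size": 128256}
    print(("llama3-8B", 1 * config["max_ctx_len"], config["vocab_size"]))
    # ('llama3-8B', 8192, 128256)
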
python/perf-kernels/utils/benchmark_utils.py (new file)

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import os
+import json
+
+# Base directory where configs are located
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
+
+
+def get_model_configs(config_path='model_configs.json', model_families=["llama3"], model="all"):
+    """
+    Load model names from the configuration file.
+
+    Args:
+        config_path (str): User-provided path to the configuration JSON file.
+        model_families (list): List of model family names to retrieve.
+
+    Returns:
+        dict: A dictionary of available models and their configurations for the specified families.
+    """
+    # Resolve config path relative to ./perf-kernels/
+    config_path = os.path.join(BASE_DIR, config_path)
+
+    with open(config_path, 'r') as f:
+        configs = json.load(f)
+
+    # Extract models and their configurations for the specified families
+    filtered_configs = {}
+
+    for family in model_families:
+        if family in configs:
+            # Check if model filtering is required
+            if model == "all":
+                # Include all models in the family
+                for model_size, model_configs in configs[family].items():
+                    filtered_configs[f"{family}-{model_size}"] = model_configs
+            else:
+                # Parse the model string (e.g., llama3_8B or llama3-8B)
+                delimiter = "_" if "_" in model else "-"
+                model_parts = model.split(delimiter)
+
+                # Check if the family and size match
+                if len(model_parts) == 2 and model_parts[0] == family:
+                    model_size = model_parts[1]
+                    if model_size in configs[family]:
+                        filtered_configs[f"{family}-{model_size}"] = configs[family][model_size]
+
+    if not filtered_configs:
+        print(f"Warning: No models selected for families: {model_families} with filter: '{model}'")
+
+    return filtered_configs
+
+
+def get_available_models(config_file='model_configs.json', model_families=["llama3"]):
+    """
+    Load model names from the configuration file.
+
+    Args:
+        config_file (str): Path to the configuration JSON file.
+        model_families (list): List of model family names to retrieve.
+
+    Returns:
+        list: A list of available models for the specified families.
+    """
+    # Resolve config path relative to ./perf-kernels/
+    config_path = os.path.join(BASE_DIR, config_file)
+
+    with open(config_path, 'r') as f:
+        configs = json.load(f)
+
+    models = [f"{family}-{model}" for family in model_families if family in configs for model in configs[family]]
+
+    return models
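
A quick usage sketch (not part of the commit) showing how the two helpers compose, assuming the nested model_configs.json above sits where BASE_DIR points:

    from utils.benchmark_utils import get_available_models, get_model_configs

    print(get_available_models(model_families=["llama3"]))
    # ['llama3-8B', 'llama3-70B', 'llama3-405B']

    # The delimiter handling accepts either spelling; both select the same entry:
    cfgs = get_model_configs(model_families=["llama3"], model="llama3_8B")
    print(list(cfgs.keys()))  # ['llama3-8B']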
