|
28 | 28 |
|
29 | 29 | import triton |
30 | 30 | import triton.language as tl |
| 31 | +from utils.benchmark_utils import get_available_models, get_model_configs |
31 | 32 |
|
32 | 33 |
|
33 | 34 | class MetaData(): |
def model_benchmark_configs(args):
    """Build flash-attention benchmark shapes for the selected llama3 model(s).

    Loads model descriptions through ``get_model_configs`` (which resolves
    ``args.model_configs`` and filters by ``args.model``) and returns a list of
    ``(model_name, batch, HQ, HK, N_CTX_Q, N_CTX_K)`` tuples.

    CLI overrides take precedence over per-model defaults: ``-b`` for batch
    size (falls back to 1), ``-sq``/``-sk`` for query/key sequence lengths
    (fall back to the model's ``max_ctx_len``).
    """
    configs = get_model_configs(config_path=args.model_configs, model_families=["llama3"], model=args.model)

    # Batch size is identical for every model, so resolve it once up front.
    batch = args.b if args.b else 1

    shapes = []
    for name, cfg in configs.items():
        heads_q = cfg["num_attention_heads"]
        kv_heads = cfg["num_key_value_heads"]
        # A null num_key_value_heads means no GQA: K/V head count mirrors Q.
        heads_k = heads_q if kv_heads is None else kv_heads
        default_ctx = cfg["max_ctx_len"]
        seq_q = args.sq if args.sq else default_ctx
        seq_k = args.sk if args.sk else default_ctx
        shapes.append((name, batch, heads_q, heads_k, seq_q, seq_k))

    return shapes
1913 | 1888 |
|
@@ -2038,16 +2013,7 @@ def parse_args(): |
2038 | 2013 | ) |
2039 | 2014 | parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.") |
2040 | 2015 |
|
2041 | | - def get_available_models(config_file='model_configs.json'): |
2042 | | - import os |
2043 | | - import json |
2044 | | - """Load model names from the configuration file.""" |
2045 | | - config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file) |
2046 | | - with open(config_path, 'r') as f: |
2047 | | - configs = json.load(f) |
2048 | | - return list(configs.keys()) |
2049 | | - |
2050 | | - available_models = get_available_models() # Dynamically load model names |
| 2016 | + available_models = get_available_models(model_families=["llama3"]) # Dynamically load model names |
2051 | 2017 | model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) + |
2052 | 2018 | "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.") |
2053 | 2019 | parser.add_argument('-model', type=str, default=None, help=model_help) |
|
0 commit comments