|
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 | import yaml |
7 | | -from _model_test_utils import _hf_model_dir_or_hub_id |
| 7 | +from _model_test_utils import get_small_model_config |
8 | 8 | from click.testing import CliRunner |
9 | 9 | from utils.cpp_paths import llm_root # noqa: F401 |
10 | 10 |
|
11 | 11 | from tensorrt_llm.commands.bench import main |
12 | 12 |
|
13 | 13 |
|
14 | | -def tiny_llama_details(): |
15 | | - model_path = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" |
16 | | - model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" |
17 | | - model_path_or_name = _hf_model_dir_or_hub_id(model_path, model_name) |
18 | | - return model_path_or_name, model_name, model_path |
19 | | - |
20 | | - |
21 | 14 | def run_benchmark(model_name: str, dataset_path: str, extra_llm_api_options_path: str): |
22 | 15 | runner = CliRunner() |
23 | 16 |
|
@@ -74,20 +67,19 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str): |
74 | 67 |
|
75 | 68 |
|
@pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"])
@pytest.mark.parametrize("model_name", ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"])
def test_trtllm_bench(llm_root, compile_backend, model_name):  # noqa: F811
    """Smoke-test the ``trtllm-bench`` CLI end-to-end for each compile backend.

    For the parametrized *model_name*, look up its small-model test config,
    write an extra-LLM-API-options YAML (the chosen ``compile_backend`` plus
    the config's ``args``), prepare a benchmark dataset, and run the benchmark.

    Args:
        llm_root: Repo-root fixture (imported from ``utils.cpp_paths``).
        compile_backend: Backend name forwarded via the options YAML.
        model_name: HF model id used to select the small-model config.
    """
    config = get_small_model_config(model_name)
    with tempfile.TemporaryDirectory() as temp_dir:
        extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml"
        with open(extra_llm_api_options_path, "w") as f:
            yaml.dump(
                {
                    "compile_backend": compile_backend,
                    # Merge the per-model knobs (e.g. reduced layer count,
                    # batch sizes) supplied by get_small_model_config so the
                    # benchmark stays small. NOTE(review): assumes "args" is
                    # always present in the config dict — confirm in
                    # _model_test_utils.
                    **config["args"],
                },
                f,
            )

        # The config's "model" entry doubles as the path/id handed to the
        # dataset preparation step.
        dataset_path = prepare_dataset(llm_root, temp_dir, config["args"]["model"])
        run_benchmark(model_name, dataset_path, extra_llm_api_options_path)
0 commit comments