@@ -4,34 +4,38 @@
 from _model_test_utils import get_small_model_config
 from build_and_run_ad import ExperimentConfig, main
 
-from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig
+from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
 
 
-def _check_ad_config(experiment_config: ExperimentConfig, ad_config: LlmArgs):
+def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
     # Verify that llm_args was captured
-    assert ad_config is not None, "llm_args should have been captured"
+    assert llm_args is not None, "llm_args should have been captured"
 
     # Check that llm_args is an instance of LlmArgs and also an instance of AutoDeployConfig
-    assert isinstance(ad_config, LlmArgs), f"Expected LlmArgs, got {type(ad_config)}"
+    assert isinstance(llm_args, LlmArgs), f"Expected LlmArgs, got {type(llm_args)}"
+    assert isinstance(llm_args, AutoDeployConfig), (
+        f"Expected AutoDeployConfig, got {type(llm_args)}"
+    )
 
     # check that llm_args and experiment_config have the same args
-    expected_llm_args: LlmArgs = experiment_config.args
-    assert expected_llm_args == ad_config, f"Expected llm args {expected_llm_args}, got {ad_config}"
+    expected_ad_config: AutoDeployConfig = experiment_config.args
+    expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.to_llm_kwargs())
+    assert expected_llm_args == llm_args, f"Expected llm args {expected_llm_args}, got {llm_args}"
 
     # check expected parallel config
-    world_size = expected_llm_args.world_size
+    world_size = expected_ad_config.world_size
     expected_parallel_config = _ParallelConfig(
         tp_size=world_size, gpus_per_node=expected_llm_args.gpus_per_node
     )
     expected_parallel_config.world_size = world_size
-    assert ad_config._parallel_config == expected_parallel_config, (
-        f"Expected parallel_config {expected_parallel_config}, got {ad_config._parallel_config}"
+    assert llm_args._parallel_config == expected_parallel_config, (
+        f"Expected parallel_config {expected_parallel_config}, got {llm_args._parallel_config}"
     )
 
     # backend should always be "_autodeploy"
-    assert ad_config.backend == "_autodeploy", (
-        f"Expected backend '_autodeploy', got {ad_config.backend}"
+    assert llm_args.backend == "_autodeploy", (
+        f"Expected backend '_autodeploy', got {llm_args.backend}"
     )
 
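In short, the updated helper now treats experiment_config.args as an AutoDeployConfig and rebuilds the expected LlmArgs from it before comparing. A minimal sketch of that round-trip, assuming to_llm_kwargs() returns a plain kwargs dict accepted by the LlmArgs constructor (only the names shown in the hunk above are real; imports are as at the top of the diff):

    # Sketch only: the AutoDeployConfig -> LlmArgs round-trip the assertions rely on.
    ad_config = experiment_config.args               # AutoDeployConfig from the experiment
    llm_args = LlmArgs(**ad_config.to_llm_kwargs())  # rebuild the full LlmArgs
    assert isinstance(llm_args, AutoDeployConfig)    # per the test, LlmArgs is also an AutoDeployConfig
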
@@ -195,7 +199,7 @@ def test_build_ad(model_hub_id: str, llm_extra_args: dict): |
     original_build_from_config = ADEngine.build_from_config
 
     @classmethod
-    def check_and_original_build(cls, ad_config: LlmArgs):
+    def check_and_original_build(cls, ad_config):
         _check_ad_config(experiment_config, ad_config)
         return original_build_from_config.__func__(cls, ad_config)
 
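The wrapper in this hunk is what feeds _check_ad_config: the test swaps ADEngine.build_from_config for check_and_original_build, which validates the captured config and then delegates to the saved original. Roughly, with the patch/restore mechanics assumed since they sit outside this diff:

    # Sketch: patch, run the build, restore; main(experiment_config) as the entry point is an assumption.
    ADEngine.build_from_config = check_and_original_build
    try:
        main(experiment_config)  # triggers ADEngine.build_from_config -> _check_ad_config
    finally:
        ADEngine.build_from_config = original_build_from_config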