Skip to content

Commit 195de17

Browse files
revert changes to test_ad_build_small_single
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 8c52dc5 commit 195de17

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,38 @@
44
from _model_test_utils import get_small_model_config
55
from build_and_run_ad import ExperimentConfig, main
66

7-
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig
7+
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
88
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
99

1010

11-
def _check_ad_config(experiment_config: ExperimentConfig, ad_config: LlmArgs):
11+
def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
1212
# Verify that llm_args was captured
13-
assert ad_config is not None, "llm_args should have been captured"
13+
assert llm_args is not None, "llm_args should have been captured"
1414

1515
# Check that llm_args is an instance of LlmArgs and also an instance of AutoDeployConfig
16-
assert isinstance(ad_config, LlmArgs), f"Expected LlmArgs, got {type(ad_config)}"
16+
assert isinstance(llm_args, LlmArgs), f"Expected LlmArgs, got {type(llm_args)}"
17+
assert isinstance(llm_args, AutoDeployConfig), (
18+
f"Expected AutoDeployConfig, got {type(llm_args)}"
19+
)
1720

1821
# check that llm_args and experiment_config have the same args
19-
expected_llm_args: LlmArgs = experiment_config.args
20-
assert expected_llm_args == ad_config, f"Expected llm args {expected_llm_args}, got {ad_config}"
22+
expected_ad_config: AutoDeployConfig = experiment_config.args
23+
expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.to_llm_kwargs())
24+
assert expected_llm_args == llm_args, f"Expected llm args {expected_llm_args}, got {llm_args}"
2125

2226
# check expected parallel config
23-
world_size = expected_llm_args.world_size
27+
world_size = expected_ad_config.world_size
2428
expected_parallel_config = _ParallelConfig(
2529
tp_size=world_size, gpus_per_node=expected_llm_args.gpus_per_node
2630
)
2731
expected_parallel_config.world_size = world_size
28-
assert ad_config._parallel_config == expected_parallel_config, (
29-
f"Expected parallel_config {expected_parallel_config}, got {ad_config._parallel_config}"
32+
assert llm_args._parallel_config == expected_parallel_config, (
33+
f"Expected parallel_config {expected_parallel_config}, got {llm_args._parallel_config}"
3034
)
3135

3236
# backend should always be "_autodeploy"
33-
assert ad_config.backend == "_autodeploy", (
34-
f"Expected backend '_autodeploy', got {ad_config.backend}"
37+
assert llm_args.backend == "_autodeploy", (
38+
f"Expected backend '_autodeploy', got {llm_args.backend}"
3539
)
3640

3741

@@ -195,7 +199,7 @@ def test_build_ad(model_hub_id: str, llm_extra_args: dict):
195199
original_build_from_config = ADEngine.build_from_config
196200

197201
@classmethod
198-
def check_and_original_build(cls, ad_config: LlmArgs):
202+
def check_and_original_build(cls, ad_config):
199203
_check_ad_config(experiment_config, ad_config)
200204
return original_build_from_config.__func__(cls, ad_config)
201205

0 commit comments

Comments (0)