Skip to content

Commit af359e2

Browse files
removed unneeded functions in integration test, fixed unit test mocking
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent ad0bda6 commit af359e2

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

tests/integration/defs/examples/test_ad_guided_decoding.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import json
217
import os
318

@@ -7,29 +22,7 @@
722
from tensorrt_llm.sampling_params import GuidedDecodingParams
823

924

10-
def autodeploy_example_root(llm_root):
11-
example_root = os.path.join(llm_root, "examples", "auto_deploy")
12-
return example_root
13-
14-
15-
def prepare_model_symlinks(llm_venv):
16-
"""Create local symlinks for models to avoid re-downloading in examples."""
17-
src_dst_dict = {
18-
# TinyLlama-1.1B-Chat-v1.0 used by the guided decoding example
19-
f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0":
20-
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
21-
}
22-
23-
for src, dst in src_dst_dict.items():
24-
if not os.path.islink(dst):
25-
os.makedirs(os.path.dirname(dst), exist_ok=True)
26-
try:
27-
os.symlink(src, dst, target_is_directory=True)
28-
except FileExistsError:
29-
pass
30-
31-
32-
def test_autodeploy_guided_decoding_main_json(llm_root, llm_venv):
25+
def test_autodeploy_guided_decoding_main_json():
3326
schema = (
3427
"{"
3528
'"title": "WirelessAccessPoint", "type": "object", "properties": {'

tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_guided_decoding.py renamed to tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_create_ad_executor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17-
from typing import Any
17+
from typing import Any, Optional
1818
from unittest.mock import Mock, patch
1919

2020
import pytest
@@ -64,6 +64,13 @@ class MockPyExecutor:
6464
guided_decoder: Any
6565

6666

67+
@dataclass
68+
class MockFactory:
69+
"""Mock Factory that stores initialization arguments."""
70+
71+
vocab_size_padded: Optional[int] = None
72+
73+
6774
"""Unit tests for create_autodeploy_executor function."""
6875

6976

@@ -113,6 +120,10 @@ def test_create_autodeploy_executor_with_guided_decoding(
113120
patch(
114121
"tensorrt_llm._torch.auto_deploy.shim.ad_executor.ADEngine.build_from_config"
115122
) as mock_ad_engine,
123+
patch(
124+
"tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs.create_factory",
125+
return_value=MockFactory(vocab_size_padded=vocab_size_padded),
126+
),
116127
):
117128
mock_ad_engine.return_value = mock_engine
118129

0 commit comments

Comments
 (0)