Commit 8950223

[fix] Remove SpecConfig and fix thread leak issues (NVIDIA#5931)
Signed-off-by: Mike Iovine <[email protected]>
1 parent: bc1d4fb

5 files changed: +15 -24 lines changed


tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -1,5 +1,5 @@
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecConfig, SpecMetadata
+from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_drafter,
@@ -13,7 +13,6 @@
     "MTPWorker",
     "NGramDrafter",
     "NGramPoolManager",
-    "SpecConfig",
     "SpecMetadata",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 0 additions & 11 deletions

@@ -105,17 +105,6 @@ def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
         return SpeculativeDecodingMode[name.upper()]
 
 
-@dataclass
-class SpecConfig:
-    """
-    Configuration for speculative decoding.
-    This class is deprecated, but thread-leak of pytest raises flaky error if removing it.
-    TODO: remove this class safely.
-    """
-    # The name of speculative decoding.
-    spec_dec_name = None
-
-
 @dataclass
 class SpecMetadata:
     """

tensorrt_llm/llmapi/llm_args.py

Lines changed: 6 additions & 6 deletions

@@ -389,15 +389,15 @@ def supports_backend(self, backend: str) -> bool:
 
 
 class MTPDecodingConfig(DecodingBaseConfig):
-    num_nextn_predict_layers: Optional[int] = 1
-    use_relaxed_acceptance_for_thinking: Optional[bool] = False
-    relaxed_topk: Optional[int] = 1
-    relaxed_delta: Optional[float] = 0.
-    use_mtp_vanilla: Optional[bool] = False
+    num_nextn_predict_layers: int = 1
+    use_relaxed_acceptance_for_thinking: bool = False
+    relaxed_topk: int = 1
+    relaxed_delta: float = 0.
+    use_mtp_vanilla: bool = False
 
     # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
     # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
-    num_nextn_predict_layers_from_model_config: Optional[int] = 1
+    num_nextn_predict_layers_from_model_config: int = 1
 
     # TODO: Hard code for DeepSeek R1
     # When encounter <think>, start thinking phase.
tests/integration/defs/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -12,3 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# This import inexplicably starts a thread!
+# This causes problems for our test infra. The issue is that TRTLLM will import
+# this module. If the import happens before the test starts, there are no problems.
+# But if the import happens lazily after the test starts, pytest will think you leaked
+# the thread. We thus do the import here to prevent thread leak issues cropping up when messing
+# with the import statements in tests.
+from torch._inductor import lowering  # NOQA
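
To see why the eager import helps, here is a hedged illustration of the failure mode, not the actual test-infra code: a leak checker snapshots live threads around a test, so a thread started by a lazy import inside the test window is reported as a leak.

import threading

def run_with_leak_check(test_fn):
    # Snapshot threads before and after the test, the way a pytest
    # thread-leak plugin conceptually works (hypothetical checker).
    before = set(threading.enumerate())
    test_fn()
    leaked = set(threading.enumerate()) - before
    if leaked:
        raise RuntimeError(f"leaked threads: {[t.name for t in leaked]}")

def test_body():
    # If this is the first import of torch._inductor.lowering in the
    # process, it may start a background thread inside the test window.
    from torch._inductor import lowering  # noqa: F401

run_with_leak_check(test_body)

Importing the module at package-import time, as this commit does in tests/integration/defs/__init__.py, moves any thread start before the "before" snapshot, so the checker no longer flags it.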

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 0 additions & 5 deletions

@@ -25,7 +25,6 @@
 import tensorrt_llm.evaluate
 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm._torch.speculative import SpecConfig
 from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi import SamplingParams
 from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
@@ -156,10 +155,6 @@ def evaluate(self,
             spec_dec_algo = None
         elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
             spec_dec_algo = llm.args.speculative_config.decoding_type
-        elif isinstance(llm.args.speculative_config, SpecConfig):
-            # This branch is deprecated, but thread-leak of pytest raises flaky error if removing it.
-            # TODO: remove this branch safely.
-            spec_dec_algo = llm.args.speculative_config.spec_dec_name
         else:
             raise ValueError(
                 f"Not recognized speculative_config: {llm.args.speculative_config}."