Commit e6086b7

fix enabled_precisions error in test cases (#3606)
1 parent: fb692c7

2 files changed: +18 -5 lines changed

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 12 additions & 5 deletions
@@ -1,15 +1,15 @@
 import inspect
 import logging
-import warnings
 from copy import deepcopy
 from enum import Enum, auto
-from typing import Any, Dict, Iterator, Optional, Union
+from typing import Any, Dict, Iterator, Optional, Set, Union

 import numpy as np
 import torch
 import torch_tensorrt
 from torch.export._trace import _export
 from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
 from torch_tensorrt.dynamo._refit import refit_module_weights
@@ -69,6 +69,7 @@ def __init__(
         strict: bool = True,
         allow_complex_guards_as_runtime_asserts: bool = False,
         weight_streaming_budget: Optional[int] = None,
+        enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -109,6 +110,7 @@ def __init__(
             hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
             timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
             lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
+            enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
             **kwargs: Any,
         Returns:
             MutableTorchTensorRTModule
@@ -159,6 +161,9 @@ def __init__(
             logger.warning(
                 "Weight stremaing budget is not set. Using auto weight streaming budget"
             )
+        self.enabled_precisions = enabled_precisions
+        if self.enabled_precisions is None:
+            self.enabled_precisions = _defaults.ENABLED_PRECISIONS

         cls = self.__class__
         self.__class__ = type(
@@ -325,9 +330,10 @@ def export_fn() -> torch.export.ExportedProgram:
                     strict=self.strict,
                 )

-        if (
-            torch.float8_e4m3fn in self.additional_settings["enabled_precisions"]
-            or torch.int8 in self.additional_settings["enabled_precisions"]
+        # Check if any quantization precision is enabled
+        if self.enabled_precisions and any(
+            precision in self.enabled_precisions
+            for precision in (torch.float8_e4m3fn, torch.int8)
         ):
             try:
                 from modelopt.torch.quantization.utils import export_torch_mode
@@ -358,6 +364,7 @@ def compile(self) -> None:
             kwarg_inputs=self.kwarg_inputs,
             immutable_weights=False,
             use_python_runtime=self.use_python_runtime,
+            enabled_precisions=self.enabled_precisions,
             **self.additional_settings,
         )
         deallocate_module(self.original_model, delete_module=False)
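
The change above makes enabled_precisions a first-class constructor argument that defaults to _defaults.ENABLED_PRECISIONS and is forwarded to dynamo_compile instead of being read from **kwargs. A minimal usage sketch, not part of this commit: it assumes a CUDA device, a TensorRT-enabled torch_tensorrt install, and the documented mutable-module workflow where compilation happens lazily around the first forward call. The toy model and shapes are illustrative only.

import torch
import torch_tensorrt

# Hypothetical toy model for illustration.
model = torch.nn.Linear(8, 4).eval().cuda()

# enabled_precisions is now an explicit keyword; leaving it as None falls back
# to _defaults.ENABLED_PRECISIONS inside __init__.
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    enabled_precisions={torch.float16},
    use_python_runtime=True,
)

x = torch.randn(2, 8).cuda()
out = mutable_module(x)  # compilation is triggered around the first call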

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py

Lines changed: 6 additions & 0 deletions
@@ -299,6 +299,9 @@ def test_save():
     not torch_trt.ENABLED_FEATURES.refit,
     "Refit feature is not supported in Python 3.13 or higher",
 )
+@unittest.skipIf(
+    not importlib.util.find_spec("torchvision"), "torchvision not installed"
+)
 @pytest.mark.unit
 def test_resnet18_modify_attribute():
     torch.manual_seed(0)
@@ -343,6 +346,9 @@ def test_resnet18_modify_attribute():
     not torch_trt.ENABLED_FEATURES.refit,
     "Refit feature is not supported in Python 3.13 or higher",
 )
+@unittest.skipIf(
+    not importlib.util.find_spec("torchvision"), "torchvision not installed"
+)
 @pytest.mark.unit
 def test_resnet18_modify_attribute_no_refit():
     torch.manual_seed(0)
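
Both tests now carry the same guard as the existing refit-feature skip: they are skipped entirely when torchvision cannot be imported. A standalone sketch of that pattern; the test class and body here are illustrative, not taken from the actual test file.

import importlib.util
import unittest


@unittest.skipIf(
    not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
class ExampleTorchvisionGuardedTest(unittest.TestCase):
    def test_resnet18_constructs(self):
        import torchvision.models as models

        # Build the model without downloading weights; the point is only that
        # the decorator keeps this test from running when torchvision is absent.
        model = models.resnet18(weights=None)
        self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()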
