|
1 | 1 | import inspect
|
2 | 2 | import logging
|
3 |
| -import warnings |
4 | 3 | from copy import deepcopy
|
5 | 4 | from enum import Enum, auto
|
6 |
| -from typing import Any, Dict, Iterator, Optional, Union |
| 5 | +from typing import Any, Dict, Iterator, Optional, Set, Union |
7 | 6 |
|
8 | 7 | import numpy as np
|
9 | 8 | import torch
|
10 | 9 | import torch_tensorrt
|
11 | 10 | from torch.export._trace import _export
|
12 | 11 | from torch_tensorrt._Device import Device
|
| 12 | +from torch_tensorrt._enums import dtype |
13 | 13 | from torch_tensorrt.dynamo import _defaults
|
14 | 14 | from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
|
15 | 15 | from torch_tensorrt.dynamo._refit import refit_module_weights
|
@@ -69,6 +69,7 @@ def __init__(
|
69 | 69 | strict: bool = True,
|
70 | 70 | allow_complex_guards_as_runtime_asserts: bool = False,
|
71 | 71 | weight_streaming_budget: Optional[int] = None,
|
| 72 | + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, |
72 | 73 | **kwargs: Any,
|
73 | 74 | ) -> None:
|
74 | 75 | """
|
@@ -109,6 +110,7 @@ def __init__(
|
109 | 110 | hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
|
110 | 111 | timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
|
111 | 112 | lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
|
| 113 | + enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels |
112 | 114 | **kwargs: Any,
|
113 | 115 | Returns:
|
114 | 116 | MutableTorchTensorRTModule
|
@@ -159,6 +161,9 @@ def __init__(
|
159 | 161 | logger.warning(
|
160 | 162 | "Weight streaming budget is not set. Using auto weight streaming budget"
|
161 | 163 | )
|
| 164 | + self.enabled_precisions = enabled_precisions |
| 165 | + if self.enabled_precisions is None: |
| 166 | + self.enabled_precisions = _defaults.ENABLED_PRECISIONS |
162 | 167 |
|
163 | 168 | cls = self.__class__
|
164 | 169 | self.__class__ = type(
|
@@ -325,9 +330,10 @@ def export_fn() -> torch.export.ExportedProgram:
|
325 | 330 | strict=self.strict,
|
326 | 331 | )
|
327 | 332 |
|
328 |
| - if ( |
329 |
| - torch.float8_e4m3fn in self.additional_settings["enabled_precisions"] |
330 |
| - or torch.int8 in self.additional_settings["enabled_precisions"] |
| 333 | + # Check if any quantization precision is enabled |
| 334 | + if self.enabled_precisions and any( |
| 335 | + precision in self.enabled_precisions |
| 336 | + for precision in (torch.float8_e4m3fn, torch.int8) |
331 | 337 | ):
|
332 | 338 | try:
|
333 | 339 | from modelopt.torch.quantization.utils import export_torch_mode
|
@@ -358,6 +364,7 @@ def compile(self) -> None:
|
358 | 364 | kwarg_inputs=self.kwarg_inputs,
|
359 | 365 | immutable_weights=False,
|
360 | 366 | use_python_runtime=self.use_python_runtime,
|
| 367 | + enabled_precisions=self.enabled_precisions, |
361 | 368 | **self.additional_settings,
|
362 | 369 | )
|
363 | 370 | deallocate_module(self.original_model, delete_module=False)
|
|
0 commit comments