Skip to content

Commit 380bb57

Browse files
authored
Cherrypick #3455 for release/2.7 (#3480)
1 parent 2015075 commit 380bb57

19 files changed

+245
-44
lines changed

.github/scripts/filter-matrix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import argparse
44
import json
55
import sys
6+
from typing import List
67

7-
disabled_python_versions = "3.13"
8+
disabled_python_versions: List[str] = []
89

910

1011
def main(args: list[str]) -> None:

.github/scripts/generate-tensorrt-test-matrix.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
# please update the future tensorRT version you want to test here
2929
TENSORRT_VERSIONS_DICT = {
3030
"windows": {
31+
"10.3.0": {
32+
"urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip",
33+
"strip_prefix": "TensorRT-10.3.0.26",
34+
},
3135
"10.7.0": {
3236
"urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/zip/TensorRT-10.7.0.23.Windows.win10.cuda-12.6.zip",
3337
"strip_prefix": "TensorRT-10.7.0.23",
@@ -42,6 +46,10 @@
4246
},
4347
},
4448
"linux": {
49+
"10.3.0": {
50+
"urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz",
51+
"strip_prefix": "TensorRT-10.3.0.26",
52+
},
4553
"10.7.0": {
4654
"urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/tars/TensorRT-10.7.0.23.Linux.x86_64-gnu.cuda-12.6.tar.gz",
4755
"strip_prefix": "TensorRT-10.7.0.23",

.github/scripts/generate_binary_build_matrix.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import sys
1919
from typing import Any, Callable, Dict, List, Optional, Tuple
2020

21+
PYTHON_VERSIONS_FOR_PR_BUILD = ["3.11"]
2122
PYTHON_ARCHES_DICT = {
22-
"nightly": ["3.9", "3.10", "3.11", "3.12"],
23-
"test": ["3.9", "3.10", "3.11", "3.12"],
24-
"release": ["3.9", "3.10", "3.11", "3.12"],
23+
"nightly": ["3.9", "3.10", "3.11", "3.12", "3.13"],
24+
"test": ["3.9", "3.10", "3.11", "3.12", "3.13"],
25+
"release": ["3.9", "3.10", "3.11", "3.12", "3.13"],
2526
}
2627
CUDA_ARCHES_DICT = {
2728
"nightly": ["11.8", "12.6", "12.8"],
2829
"test": ["11.8", "12.6", "12.8"],
29-
"release": ["11.8", "12.6", "12.8"],
30+
"release": ["11.8", "12.4", "12.6"],
3031
}
3132
ROCM_ARCHES_DICT = {
3233
"nightly": ["6.1", "6.2"],
@@ -422,11 +423,6 @@ def generate_wheels_matrix(
422423
# Define default python version
423424
python_versions = list(PYTHON_ARCHES)
424425

425-
# If the list of python versions is set explicitly by the caller, stick with it instead
426-
# of trying to add more versions behind the scene
427-
if channel == NIGHTLY and (os in (LINUX, MACOS_ARM64, LINUX_AARCH64)):
428-
python_versions += ["3.13"]
429-
430426
if os == LINUX:
431427
# NOTE: We only build manywheel packages for linux
432428
package_type = "manywheel"
@@ -456,7 +452,7 @@ def generate_wheels_matrix(
456452
arches += [XPU]
457453

458454
if limit_pr_builds:
459-
python_versions = [python_versions[0]]
455+
python_versions = PYTHON_VERSIONS_FOR_PR_BUILD
460456

461457
global WHEEL_CONTAINER_IMAGES
462458

.github/workflows/build-test-linux.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ jobs:
2323
test-infra-ref: main
2424
with-rocm: false
2525
with-cpu: false
26-
python-versions: '["3.11", "3.12", "3.10", "3.9"]'
2726

2827
filter-matrix:
2928
needs: [generate-matrix]

.github/workflows/build-test-windows.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ jobs:
2323
test-infra-ref: main
2424
with-rocm: false
2525
with-cpu: false
26-
python-versions: '["3.11", "3.12", "3.10", "3.9"]'
2726

2827
substitute-runner:
2928
needs: generate-matrix

py/torch_tensorrt/_features.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"torch_tensorrt_runtime",
1515
"dynamo_frontend",
1616
"fx_frontend",
17+
"refit",
1718
],
1819
)
1920

@@ -36,9 +37,10 @@
3637
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
3738
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
3839
_FX_FE_AVAIL = True
40+
_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
3941

4042
ENABLED_FEATURES = FeatureSet(
41-
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL
43+
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
4244
)
4345

4446

@@ -62,6 +64,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
6264
return wrapper
6365

6466

67+
def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
68+
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
69+
if ENABLED_FEATURES.refit:
70+
return f(*args, **kwargs)
71+
else:
72+
73+
def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
74+
raise NotImplementedError(
75+
"Refit feature is currently not available in Python 3.13 or higher"
76+
)
77+
78+
return not_implemented(*args, **kwargs)
79+
80+
return wrapper
81+
82+
6583
T = TypeVar("T")
6684

6785

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torch.export import ExportedProgram
1212
from torch_tensorrt._enums import dtype
13+
from torch_tensorrt._features import needs_refit
1314
from torch_tensorrt._Input import Input
1415
from torch_tensorrt.dynamo import partitioning
1516
from torch_tensorrt.dynamo._exporter import inline_torch_modules
@@ -46,6 +47,7 @@
4647
logger = logging.getLogger(__name__)
4748

4849

50+
@needs_refit
4951
def construct_refit_mapping(
5052
module: torch.fx.GraphModule,
5153
inputs: Sequence[Input],
@@ -107,8 +109,11 @@ def construct_refit_mapping(
107109
return weight_map
108110

109111

112+
@needs_refit
110113
def construct_refit_mapping_from_weight_name_map(
111-
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
114+
weight_name_map: dict[Any, Any],
115+
state_dict: dict[Any, Any],
116+
settings: CompilationSettings,
112117
) -> dict[Any, Any]:
113118
engine_weight_map = {}
114119
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
@@ -119,7 +124,9 @@ def construct_refit_mapping_from_weight_name_map(
119124
# If weights is not in sd, we can leave it unchanged
120125
continue
121126
else:
122-
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
127+
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
128+
to_torch_device(settings.device)
129+
)
123130

124131
engine_weight_map[engine_weight_name] = (
125132
engine_weight_map[engine_weight_name]
@@ -133,6 +140,7 @@ def construct_refit_mapping_from_weight_name_map(
133140
return engine_weight_map
134141

135142

143+
@needs_refit
136144
def _refit_single_trt_engine_with_gm(
137145
new_gm: torch.fx.GraphModule,
138146
old_engine: trt.ICudaEngine,
@@ -161,7 +169,7 @@ def _refit_single_trt_engine_with_gm(
161169
"constant_mapping", {}
162170
) # type: ignore
163171
mapping = construct_refit_mapping_from_weight_name_map(
164-
weight_name_map, new_gm.state_dict()
172+
weight_name_map, new_gm.state_dict(), settings
165173
)
166174
constant_mapping_with_type = {}
167175

@@ -211,6 +219,7 @@ def _refit_single_trt_engine_with_gm(
211219
raise AssertionError("Refitting failed.")
212220

213221

222+
@needs_refit
214223
def refit_module_weights(
215224
compiled_module: torch.fx.GraphModule | ExportedProgram,
216225
new_weight_module: ExportedProgram,

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from torch.fx.passes.shape_prop import TensorMetadata
2626
from torch.utils._python_dispatch import _disable_current_modes
2727
from torch_tensorrt._enums import dtype
28+
from torch_tensorrt._features import needs_refit
2829
from torch_tensorrt._Input import Input
2930
from torch_tensorrt.dynamo import _defaults
3031
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -42,7 +43,7 @@
4243
get_node_name,
4344
get_trt_tensor,
4445
)
45-
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
46+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
4647
from torch_tensorrt.fx.observer import Observer
4748
from torch_tensorrt.logging import TRT_LOGGER
4849

@@ -430,6 +431,7 @@ def check_weight_equal(
430431
except Exception:
431432
return torch.all(sd_weight == network_weight)
432433

434+
@needs_refit
433435
def _save_weight_mapping(self) -> None:
434436
"""
435437
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -487,15 +489,10 @@ def _save_weight_mapping(self) -> None:
487489
_LOGGER.info("Building weight name mapping...")
488490
# Stage 1: Name mapping
489491
torch_device = to_torch_device(self.compilation_settings.device)
490-
gm_is_on_cuda = get_model_device(self.module).type == "cuda"
491-
if not gm_is_on_cuda:
492-
# If the model original position is on CPU, move it GPU
493-
sd = {
494-
k: v.reshape(-1).to(torch_device)
495-
for k, v in self.module.state_dict().items()
496-
}
497-
else:
498-
sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
492+
sd = {
493+
k: v.reshape(-1).to(torch_device)
494+
for k, v in self.module.state_dict().items()
495+
}
499496
weight_name_map: dict[str, Any] = {}
500497
np_map = {}
501498
constant_mapping = {}
@@ -579,6 +576,7 @@ def _save_weight_mapping(self) -> None:
579576
gc.collect()
580577
torch.cuda.empty_cache()
581578

579+
@needs_refit
582580
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
583581
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
584582
# if not self.compilation_settings.strip_engine_weights:
@@ -606,6 +604,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
606604
),
607605
)
608606

607+
@needs_refit
609608
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
610609
# query the cached TRT engine
611610
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
@@ -716,7 +715,7 @@ def run(
716715
if self.compilation_settings.reuse_cached_engines:
717716
interpreter_result = self._pull_cached_engine(hash_val)
718717
if interpreter_result is not None: # hit the cache
719-
return interpreter_result
718+
return interpreter_result # type: ignore[no-any-return]
720719

721720
self._construct_trt_network_def()
722721

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[build-system]
22
requires = [
3-
"setuptools>=68.0.0",
3+
"setuptools>=77.0.0",
44
"packaging>=23.1",
55
"wheel>=0.40.0",
66
"ninja>=1.11.0",

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
import torch
1919
import yaml
2020
from setuptools import Extension, find_namespace_packages, setup
21+
from setuptools.command.bdist_wheel import bdist_wheel
2122
from setuptools.command.build_ext import build_ext
2223
from setuptools.command.develop import develop
2324
from setuptools.command.editable_wheel import editable_wheel
2425
from setuptools.command.install import install
2526
from torch.utils.cpp_extension import IS_WINDOWS, BuildExtension, CUDAExtension
26-
from wheel.bdist_wheel import bdist_wheel
2727

2828
__version__: str = "0.0.0"
2929
__cuda_version__: str = "0.0"

0 commit comments

Comments
 (0)