Skip to content

Commit cae6b7c

Browse files
authored
Merge pull request #1975 from pytorch/torch_version_upgrade_jun23_nightly
chore: Upgrade Torch nightly to `2.1.0.dev20230605` [4 / x]
2 parents 82631fa + 784ec9b commit cae6b7c

File tree

13 files changed

+98
-32
lines changed

13 files changed

+98
-32
lines changed

.bazelrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# +------------------------------------------------------------+
2323
# Enable colorful output of GCC
2424
build --cxxopt="-fdiagnostics-color=always"
25-
build --cxxopt='-std=c++14'
25+
build --cxxopt='-std=c++17'
2626
#build --linkopt="-Wl,--no-as-needed"
2727

2828

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,10 @@ commands:
269269
parameters:
270270
torch-build:
271271
type: string
272-
default: "2.1.0.dev20230419+cu118"
272+
default: "2.1.0.dev20230605+cu118"
273273
torchvision-build:
274274
type: string
275-
default: "0.16.0.dev20230419+cu118"
275+
default: "0.16.0.dev20230605+cu118"
276276
torch-build-index:
277277
type: string
278278
default: "https://download.pytorch.org/whl/nightly/cu118"
@@ -1352,10 +1352,10 @@ parameters:
13521352
# Nightly platform config
13531353
torch-build:
13541354
type: string
1355-
default: "2.1.0.dev20230419+cu118"
1355+
default: "2.1.0.dev20230605+cu118"
13561356
torchvision-build:
13571357
type: string
1358-
default: "0.16.0.dev20230419+cu118"
1358+
default: "0.16.0.dev20230605+cu118"
13591359
torch-build-index:
13601360
type: string
13611361
default: "https://download.pytorch.org/whl/nightly/cu118"

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
cmake_minimum_required(VERSION 3.17)
33
project(Torch-TensorRT LANGUAGES CXX)
44

5-
# use c++14 like PyTorch
6-
set(CMAKE_CXX_STANDARD 14)
5+
# use c++17 like PyTorch
6+
set(CMAKE_CXX_STANDARD 17)
77

88
# Build the libraries with -fPIC
99
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
116116
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
117117

118118
- Bazel 5.2.0
119-
- Libtorch 2.1.0.dev20230419 (built with CUDA 11.8)
119+
- Libtorch 2.1.0.dev20230605 (built with CUDA 11.8)
120120
- CUDA 11.8
121121
- cuDNN 8.8.0
122122
- TensorRT 8.6.1

WORKSPACE

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,17 @@ new_local_repository(
5151
http_archive(
5252
name = "libtorch",
5353
build_file = "@//third_party/libtorch:BUILD",
54-
sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
54+
sha256 = "999becce82b73e566d0ffe010cd21fea8cf3a33f90f09dcc6b01150b820ae063",
5555
strip_prefix = "libtorch",
56-
urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
56+
urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
5757
)
5858

5959
http_archive(
6060
name = "libtorch_pre_cxx11_abi",
6161
build_file = "@//third_party/libtorch:BUILD",
62-
sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
62+
sha256 = "786cc728c63ea69c40bd8fb535cf8e5e1dfff1d43eaad3eb5256b9ed89c1b268",
6363
strip_prefix = "libtorch",
64-
urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
64+
urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
6565
)
6666

6767
# Download these tarballs manually from the NVIDIA website

py/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ numpy
22
packaging
33
pybind11==2.6.2
44
--extra-index-url https://download.pytorch.org/whl/nightly/cu118
5-
torch==2.1.0.dev20230419+cu118
6-
torchvision==0.16.0.dev20230419+cu118
5+
torch==2.1.0.dev20230605+cu118
6+
torchvision==0.16.0.dev20230605+cu118
77
--extra-index-url https://pypi.ngc.nvidia.com
88
tensorrt==8.6.1

py/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def run(self):
427427
ext_modules=ext_modules,
428428
install_requires=[
429429
"torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0",
430+
"pyyaml",
431+
"packaging",
430432
],
431433
setup_requires=[],
432434
cmdclass={

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
partition,
1313
get_submod_inputs,
1414
)
15+
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
1516
from torch_tensorrt.dynamo.backend.conversion import convert_module
1617

1718
from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -25,22 +26,20 @@
2526
@td.register_backend(name="torch_tensorrt")
2627
@fake_tensor_unsupported
2728
def torch_tensorrt_backend(
28-
gm: torch.fx.GraphModule,
29-
sample_inputs: Sequence[torch.Tensor],
30-
settings: CompilationSettings = CompilationSettings(),
29+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
3130
):
3231
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3332

34-
return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
33+
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
3534

3635

3736
@td.register_backend(name="aot_torch_tensorrt_aten")
3837
@fake_tensor_unsupported
3938
def aot_torch_tensorrt_aten_backend(
40-
gm: torch.fx.GraphModule,
41-
sample_inputs: Sequence[torch.Tensor],
42-
settings: CompilationSettings = CompilationSettings(),
39+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
4340
):
41+
settings = parse_dynamo_kwargs(kwargs)
42+
4443
custom_backend = partial(
4544
_pretraced_backend,
4645
settings=settings,

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,21 @@ def convert_module(
2424
Returns:
2525
TRTModule or TRTModuleNext
2626
"""
27+
# Specify module output data types to ensure TRT output types agree with
28+
# that of the equivalent Torch module
29+
module_outputs = module(*inputs)
30+
31+
if not isinstance(module_outputs, (list, tuple)):
32+
module_outputs = [module_outputs]
33+
34+
output_dtypes = list(output.dtype for output in module_outputs)
35+
2736
interpreter = TRTInterpreter(
2837
module,
2938
InputTensorSpec.from_tensors(inputs),
3039
explicit_batch_dimension=True,
3140
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
41+
output_dtypes=output_dtypes,
3242
)
3343

3444
interpreter_result = interpreter.run(

py/torch_tensorrt/dynamo/backend/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import torch
2+
import logging
3+
from dataclasses import replace, fields
24

5+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
36
from typing import Any, Union, Sequence, Dict
47
from torch_tensorrt import _Input, Device
58

69

10+
logger = logging.getLogger(__name__)
11+
12+
713
def prepare_inputs(
814
inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
915
device: torch.device = torch.device("cuda"),
@@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
6672
)
6773

6874
return device
75+
76+
77+
def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
78+
"""Parses the kwargs field of a Dynamo backend
79+
80+
Args:
81+
kwargs: Keyword arguments dictionary provided to the backend
82+
Returns:
83+
CompilationSettings object with relevant kwargs
84+
"""
85+
86+
# Initialize an empty CompilationSettings object
87+
settings = CompilationSettings()
88+
89+
# If the user specifies keyword args, overwrite those fields in settings
90+
# Validate all specified kwargs to ensure they are true fields of the dataclass
91+
#
92+
# Note: kwargs provided by torch.compile are wrapped in the "options" key
93+
if kwargs:
94+
if "options" in kwargs and len(kwargs) == 1:
95+
kwargs = kwargs["options"]
96+
97+
valid_attrs = {attr.name for attr in fields(settings)}
98+
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
99+
settings = replace(settings, **valid_kwargs)
100+
101+
# Enable debug/verbose mode if requested
102+
if settings.debug:
103+
logger.setLevel(logging.DEBUG)
104+
105+
logger.debug(f"Compiling with Settings:\n{settings}")
106+
107+
return settings

0 commit comments

Comments
 (0)