Skip to content

Commit 0499493

Browse files
authored
2.3 cherry pick feat: Adding support for native int64 (#2789) (#2802)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 9e0b547 commit 0499493

28 files changed

+383
-106
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
- --fix=lf
1717
exclude: ^docs
1818
- repo: https://github.com/pre-commit/mirrors-clang-format
19-
rev: v18.1.1
19+
rev: v14.0.6
2020
hooks:
2121
- id: clang-format
2222
types_or: [c++, c, cuda]

core/runtime/register_jit_hooks.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
122122
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
123123
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
124124
});
125+
m.def("set_logging_level", [](int64_t level) -> void {
126+
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
127+
});
128+
m.def(
129+
"get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
125130
}
126131

127132
} // namespace

core/util/trt_util.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
292292
{at::kFloat, nvinfer1::DataType::kFLOAT},
293293
{at::kHalf, nvinfer1::DataType::kHALF},
294294
{at::kInt, nvinfer1::DataType::kINT32},
295-
{at::kLong, nvinfer1::DataType::kINT32},
295+
{at::kLong, nvinfer1::DataType::kINT64},
296296
{at::kChar, nvinfer1::DataType::kINT8},
297297
{at::kByte, nvinfer1::DataType::kINT8},
298298
{at::kBool, nvinfer1::DataType::kBOOL}};
@@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
304304
{nvinfer1::DataType::kFLOAT, at::kFloat},
305305
{nvinfer1::DataType::kHALF, at::kHalf},
306306
{nvinfer1::DataType::kINT32, at::kInt},
307+
{nvinfer1::DataType::kINT64, at::kLong},
307308
{nvinfer1::DataType::kINT8, at::kChar},
308309
{nvinfer1::DataType::kBOOL, at::kBool},
309310
};

core/util/trt_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
5353
return stream << "Int8";
5454
case nvinfer1::DataType::kINT32:
5555
return stream << "Int32";
56+
case nvinfer1::DataType::kINT64:
57+
return stream << "Int64";
5658
case nvinfer1::DataType::kBOOL:
5759
return stream << "Bool";
5860
default:

py/torch_tensorrt/_enums.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from typing import Any, Optional, Type, Union
66

77
import numpy as np
8-
import tensorrt as trt
98
import torch
109
from torch_tensorrt._features import ENABLED_FEATURES
1110

11+
import tensorrt as trt
12+
1213

1314
class dtype(Enum):
1415
"""Enum to set supported dtypes in the compiler"""
@@ -103,6 +104,8 @@ def _from(
103104
return dtype.i8
104105
elif t == trt.int32:
105106
return dtype.i32
107+
elif t == trt.int64:
108+
return dtype.i64
106109
elif t == trt.float16:
107110
return dtype.f16
108111
elif t == trt.float32:
@@ -227,6 +230,8 @@ def to(
227230
return trt.DataType.INT8
228231
elif self == dtype.i32:
229232
return trt.DataType.INT32
233+
elif self == dtype.i64:
234+
return trt.DataType.INT64
230235
elif self == dtype.f16:
231236
return trt.DataType.HALF
232237
elif self == dtype.f32:

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import logging
5+
import warnings
56
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
67

78
import torch
@@ -22,7 +23,7 @@
2223
UnsupportedOperatorException,
2324
convert_module,
2425
interpret_module_to_result,
25-
repair_long_or_double_inputs,
26+
repair_double_inputs,
2627
)
2728
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
2829
DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
5859
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
5960
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
6061
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
61-
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
62+
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
6263
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
6364
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
6465
torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@ def compile(
7475
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
7576
**kwargs: Any,
7677
) -> torch.fx.GraphModule:
77-
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
78+
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
7879
7980
Takes an existing TorchScript module and a set of settings to configure the compiler
8081
and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@ def compile(
115116
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
116117
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
117118
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
118-
truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
119+
truncate_double (bool): Truncate weights provided in double (float64) to float32
119120
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
120121
require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
121122
min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@ def compile(
138139
if debug:
139140
set_log_level(logger.parent, logging.DEBUG)
140141

142+
if "truncate_long_and_double" in kwargs.keys():
143+
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
144+
raise ValueError(
145+
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
146+
)
147+
else:
148+
truncate_double = kwargs["truncate_long_and_double"]
149+
warnings.warn(
150+
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
151+
DeprecationWarning,
152+
stacklevel=2,
153+
)
154+
141155
engine_capability = EngineCapability._from(engine_capability)
142156

143157
if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@ def compile(
185199
"version_compatible": version_compatible,
186200
"optimization_level": optimization_level,
187201
"use_python_runtime": use_python_runtime,
188-
"truncate_long_and_double": truncate_long_and_double,
202+
"truncate_double": truncate_double,
189203
"use_fast_partitioner": use_fast_partitioner,
190204
"num_avg_timing_iters": num_avg_timing_iters,
191205
"enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
349363

350364
assert submodule_inputs is not None
351365
# Handle double inputs if requested by the user
352-
if settings.truncate_long_and_double:
353-
submodule_inputs = repair_long_or_double_inputs(
366+
if settings.truncate_double:
367+
submodule_inputs = repair_double_inputs(
354368
partitioned_module,
355369
submodule,
356370
submodule_inputs,
@@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
423437

424438
def convert_module_to_trt_engine(
425439
exported_program: ExportedProgram,
426-
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
440+
inputs: Tuple[Any, ...],
441+
*,
427442
enabled_precisions: (
428443
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
429444
) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
436451
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
437452
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
438453
use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
439-
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
454+
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
440455
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
441456
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
442457
device: Device = Device._current_device(),
@@ -451,6 +466,7 @@ def convert_module_to_trt_engine(
451466
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
452467
calibrator: object = None,
453468
allow_shape_tensors: bool = False,
469+
**kwargs: Any,
454470
) -> bytes:
455471
"""Convert an ExportedProgram to a serialized TensorRT engine
456472
@@ -488,7 +504,7 @@ def convert_module_to_trt_engine(
488504
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
489505
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
490506
argument as None
491-
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
507+
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
492508
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
493509
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
494510
or only a selected subset of them
@@ -512,6 +528,19 @@ def convert_module_to_trt_engine(
512528
if debug:
513529
set_log_level(logger.parent, logging.DEBUG)
514530

531+
if "truncate_long_and_double" in kwargs.keys():
532+
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
533+
raise ValueError(
534+
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
535+
)
536+
else:
537+
truncate_double = kwargs["truncate_long_and_double"]
538+
warnings.warn(
539+
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
540+
DeprecationWarning,
541+
stacklevel=2,
542+
)
543+
515544
input_list = list(inputs) if inputs is not None else []
516545
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
517546
# Prepare torch_trt inputs
@@ -531,7 +560,7 @@ def convert_module_to_trt_engine(
531560
"version_compatible": version_compatible,
532561
"optimization_level": optimization_level,
533562
"use_python_runtime": use_python_runtime,
534-
"truncate_long_and_double": truncate_long_and_double,
563+
"truncate_double": truncate_double,
535564
"use_fast_partitioner": use_fast_partitioner,
536565
"enable_experimental_decompositions": enable_experimental_decompositions,
537566
"device": device,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
VERSION_COMPATIBLE = False
1919
OPTIMIZATION_LEVEL = None
2020
SPARSE_WEIGHTS = False
21-
TRUNCATE_LONG_AND_DOUBLE = False
21+
TRUNCATE_DOUBLE = False
2222
USE_PYTHON_RUNTIME = False
2323
USE_FAST_PARTITIONER = True
2424
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Collection, Optional, Union
2+
from typing import Collection, Optional, Set, Union
33

44
from torch.fx.node import Target
55
from torch_tensorrt._Device import Device
@@ -23,7 +23,7 @@
2323
REFIT,
2424
REQUIRE_FULL_COMPILATION,
2525
SPARSE_WEIGHTS,
26-
TRUNCATE_LONG_AND_DOUBLE,
26+
TRUNCATE_DOUBLE,
2727
USE_FAST_PARTITIONER,
2828
USE_PYTHON_RUNTIME,
2929
VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
5050
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
5151
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
5252
argument as None
53-
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
53+
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
5454
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
5555
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
5656
or only a selected subset of them
@@ -71,7 +71,7 @@ class CompilationSettings:
7171
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
7272
"""
7373

74-
enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS)
74+
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
7575
debug: bool = DEBUG
7676
workspace_size: int = WORKSPACE_SIZE
7777
min_block_size: int = MIN_BLOCK_SIZE
@@ -81,7 +81,7 @@ class CompilationSettings:
8181
version_compatible: bool = VERSION_COMPATIBLE
8282
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
8383
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
84-
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
84+
truncate_double: bool = TRUNCATE_DOUBLE
8585
use_fast_partitioner: bool = USE_FAST_PARTITIONER
8686
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
8787
device: Device = field(default_factory=default_device)

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,20 @@
2121
from torch.fx.node import Argument, Node, Target, _get_qualified_name
2222
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
2323
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
24-
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
24+
25+
import tensorrt as trt
2526

2627
logger = logging.getLogger(__name__)
2728

2829
LegacyConverterImplSignature = Callable[
2930
[
30-
TRTNetwork,
31+
trt.INetworkDefinition,
3132
Target,
3233
Tuple[Argument, ...],
3334
Dict[str, Argument],
3435
str,
3536
],
36-
Union[TRTTensor, Sequence[TRTTensor]],
37+
Union[trt.ITensor, Sequence[trt.ITensor]],
3738
]
3839

3940
DynamoConverterImplSignature = Callable[
@@ -44,7 +45,7 @@
4445
Dict[str, Argument],
4546
str,
4647
],
47-
Union[TRTTensor, Sequence[TRTTensor]],
48+
Union[trt.ITensor, Sequence[trt.ITensor]],
4849
]
4950

5051
ConverterImplSignature = Union[

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7-
import tensorrt as trt
87
import torch
98
import torch.fx
109
from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
2625
from torch_tensorrt.fx.observer import Observer
2726
from torch_tensorrt.logging import TRT_LOGGER
2827

28+
import tensorrt as trt
2929
from packaging import version
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
498498
)
499499

500500
for i, output in enumerate(outputs):
501+
name = f"output{i}"
502+
503+
output_dtype = dtype.unknown
501504
if any(
502505
op_name in output.name.split("_")
503506
for op_name in (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
514517
"any",
515518
)
516519
):
517-
output_bool = True
518-
else:
519-
output_bool = False
520-
name = f"output{i}"
521-
output.name = name
522-
self.ctx.net.mark_output(output)
523-
if output_bool:
524-
output.dtype = trt.DataType.BOOL
520+
output_dtype = dtype.b
525521
elif self.output_dtypes is not None:
526-
output.dtype = self.output_dtypes[i].to(trt.DataType)
522+
if self.output_dtypes[i] == dtype.i64:
523+
output = self.ctx.net.add_cast(
524+
output, dtype.i64.to(trt.DataType)
525+
).get_output(0)
526+
output_dtype = dtype.i64
527+
else:
528+
output_dtype = self.output_dtypes[i]
529+
530+
self.ctx.net.mark_output(output)
531+
if output_dtype is not dtype.unknown:
532+
output.dtype = output_dtype.to(trt.DataType, use_default=True)
533+
output.name = name
527534

528535
self._output_names.append(name)
529536
_LOGGER.debug(

0 commit comments

Comments
 (0)