
Commit cbed1b9

refactor(//py): Correct exports for the package
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent: c0866e2

10 files changed: +349 −83 lines


py/BUILD

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+package(default_visibility = ["//visibility:public"])
+load("@trtorch_py_deps//:requirements.bzl", "requirement")
+
+
+# Exposes the library for testing
+py_library(
+    name = "trtorch",
+    srcs = [
+        "trtorch/__init__.py",
+        "trtorch/_version.py",
+        "trtorch/_compiler.py",
+        "trtorch/_extra_info.py",
+        "trtorch/_types.py",
+        "trtorch/logging.py"
+    ],
+    data = [
+        "trtorch/lib/libtrtorch.so"
+    ] + glob([
+        "trtorch/_C.cpython*.so"
+    ]),
+    deps = [
+        requirement("torch")
+    ]
+)
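For context, a downstream Bazel target could consume this library as in the following hypothetical BUILD fragment; the target name and test file are illustrative and not part of this commit:

# Hypothetical consumer of the py_library above.
py_test(
    name = "api_test",                 # illustrative target name
    srcs = ["tests/test_api.py"],      # illustrative test file
    deps = [
        ":trtorch",                    # the library defined in this BUILD file
        requirement("torch"),
    ],
)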

py/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+torch==1.5.0

py/setup.py

Lines changed: 24 additions & 4 deletions
@@ -11,15 +11,33 @@
 from torch.utils import cpp_extension
 from shutil import copyfile, rmtree
 
+import subprocess
+
 dir_path = os.path.dirname(os.path.realpath(__file__))
 
 __version__ = '0.0.1'
 
+def build_libtrtorch_pre_cxx11_abi(develop=True):
+    cmd = ["/usr/bin/bazel", "build"]
+    cmd.append("//cpp/api/lib:libtrtorch.so")
+    if develop:
+        cmd.append("--compilation_mode=dbg")
+    else:
+        cmd.append("--compilation_mode=opt")
+    cmd.append("--config=python")
+
+    print("building libtrtorch")
+    status_code = subprocess.run(cmd).returncode
+
+    if status_code != 0:
+        sys.exit(status_code)
+
+
 def gen_version_file():
-    if not os.path.exists(dir_path + '/trtorch/version.py'):
-        os.mknod(dir_path + '/trtorch/version.py')
+    if not os.path.exists(dir_path + '/trtorch/_version.py'):
+        os.mknod(dir_path + '/trtorch/_version.py')
 
-    with open(dir_path + '/trtorch/version.py', 'w') as f:
+    with open(dir_path + '/trtorch/_version.py', 'w') as f:
         print("creating version file")
         f.write("__version__ = \"" + __version__ + '\"')
 
@@ -40,6 +58,7 @@ def finalize_options(self):
         develop.finalize_options(self)
 
     def run(self):
+        build_libtrtorch_pre_cxx11_abi(develop=True)
         gen_version_file()
         copy_libtrtorch()
         develop.run(self)
@@ -55,6 +74,7 @@ def finalize_options(self):
         install.finalize_options(self)
 
     def run(self):
+        build_libtrtorch_pre_cxx11_abi(develop=False)
         gen_version_file()
         copy_libtrtorch()
         install.run(self)
@@ -110,7 +130,7 @@ def run(self):
 setup(
     name='trtorch',
     version=__version__,
-    author='NVIDIA Corporation.',
+    author='NVIDIA',
     author_email='[email protected]',
     url='https://github.com/nvidia/trtorch',
     description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
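The hunks above wire the Bazel build of libtrtorch.so into the standard setuptools flow by overriding each command's run(). A minimal, self-contained sketch of that cmdclass pattern, with a placeholder hook body standing in for the build_libtrtorch_pre_cxx11_abi / gen_version_file / copy_libtrtorch calls shown in the diff:

# Sketch of the setuptools command-override pattern used in setup.py:
# subclass a command, do extra work in run(), then defer to the stock step.
from setuptools import setup
from setuptools.command.develop import develop

class DevelopCommand(develop):
    def run(self):
        # Placeholder for the pre-build hook; setup.py calls
        # build_libtrtorch_pre_cxx11_abi(develop=True) here.
        print("running pre-build hook")
        develop.run(self)

setup(
    name='example-package',                # illustrative metadata
    version='0.0.1',
    cmdclass={'develop': DevelopCommand},  # 'python setup.py develop' now runs the hook
)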

py/trtorch/__init__.py

Lines changed: 4 additions & 7 deletions
@@ -15,10 +15,7 @@ def _load_trtorch_lib():
 
 _load_trtorch_lib()
 
-from .version import __version__
-from trtorch import _C
-from trtorch.compiler import *
-from trtorch.types import *
-
-def test(mod, data):
-    _C._test(mod._c, data)
+from trtorch._version import __version__
+from trtorch._compiler import *
+from trtorch._types import *
+from trtorch import logging
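After this change the package's public names come from the underscored modules. A hedged sketch of what resolves at the top level (requires a built trtorch extension to import):

import trtorch

print(trtorch.__version__)              # re-exported from trtorch._version
compile_fn = trtorch.compile            # from trtorch._compiler via the star import
cap = trtorch.EngineCapability.DEFAULT  # from trtorch._types via the star import
log = trtorch.logging                   # submodule now exported explicitly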

py/trtorch/_compiler.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from typing import List, Dict, Any
+import torch
+import trtorch._C
+from trtorch._extra_info import _parse_extra_info
+from trtorch._version import __version__
+
+def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
+    """Compile a TorchScript module for NVIDIA GPUs using TensorRT
+
+    Takes an existing TorchScript module and a set of settings to configure the compiler,
+    and converts methods to JIT Graphs which call equivalent TensorRT engines
+
+    Specifically converts the forward method of a TorchScript Module
+
+    Args:
+        module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch
+            ``torch.nn.Module``
+        extra_info (dict): Compilation settings including operating precision, target device, etc.
+            One key is required, ``input_shapes``, which describes the input sizes or ranges for the
+            inputs to the graph. All other keys are optional
+
+            .. code-block:: py
+
+                ExtraInfo = {
+                    "input_shapes": [
+                        (1, 3, 224, 224), # Static input shape for input #1
+                        {
+                            "min": (1, 3, 224, 224),
+                            "opt": (1, 3, 512, 512),
+                            "max": (1, 3, 1024, 1024)
+                        } # Dynamic input shape for input #2
+                    ],
+                    "op_precision": torch.half, # Operating precision set to FP16
+                    "refit": False, # enable refit
+                    "debug": False, # enable debuggable engine
+                    "strict_types": False, # kernels should strictly run in operating precision
+                    "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
+                    "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                    "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+                    "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
+                    "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
+                    "workspace_size": 0, # Maximum size of workspace given to TensorRT
+                    "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+                }
+
+            Input sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+            torch datatypes or trtorch datatypes, and you can use either torch devices or the trtorch device type
+            enum to select the device type.
+
+    Returns:
+        torch.jit.ScriptModule: Compiled TorchScript Module; when run it will execute via TensorRT
+    """
+    compiled_cpp_mod = trtorch._C._compile_graph(module._c, _parse_extra_info(extra_info))
+    compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
+    return compiled_module
+
+def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
+    """Convert a TorchScript module method to a serialized TensorRT engine
+
+    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+
+    Args:
+        module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch
+            ``torch.nn.Module``
+        method_name (str): Name of the method to convert
+        extra_info (dict): Compilation settings including operating precision, target device, etc.
+            One key is required, ``input_shapes``, which describes the input sizes or ranges for the
+            inputs to the graph. All other keys are optional
+
+            .. code-block:: py
+
+                ExtraInfo = {
+                    "input_shapes": [
+                        (1, 3, 224, 224), # Static input shape for input #1
+                        {
+                            "min": (1, 3, 224, 224),
+                            "opt": (1, 3, 512, 512),
+                            "max": (1, 3, 1024, 1024)
+                        } # Dynamic input shape for input #2
+                    ],
+                    "op_precision": torch.half, # Operating precision set to FP16
+                    "refit": False, # enable refit
+                    "debug": False, # enable debuggable engine
+                    "strict_types": False, # kernels should strictly run in operating precision
+                    "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
+                    "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
+                    "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+                    "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
+                    "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
+                    "workspace_size": 0, # Maximum size of workspace given to TensorRT
+                    "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+                }
+
+            Input sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+            torch datatypes or trtorch datatypes, and you can use either torch devices or the trtorch device type
+            enum to select the device type.
+
+    Returns:
+        bytes: Serialized TensorRT engine; can either be saved to a file or deserialized via TensorRT APIs
+    """
+    return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
+
+def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
+    """Checks whether a method is fully supported by TRTorch
+
+    Checks if a method of a TorchScript module can be compiled by TRTorch. If it cannot, a list of the
+    unsupported operators is printed and the function returns False; otherwise it returns True.
+
+    Args:
+        module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch
+            ``torch.nn.Module``
+        method_name (str): Name of the method to check
+
+    Returns:
+        bool: True if the method is fully supported
+    """
+    return trtorch._C._check_method_op_support(module._c, method_name)
+
+def dump_build_info():
+    """Prints build information about the TRTorch distribution to stdout
+    """
+    print(get_build_info())
+
+def get_build_info() -> str:
+    """Returns a string containing the build information of the TRTorch distribution
+
+    Returns:
+        str: String containing the build information for the TRTorch distribution
+    """
+    build_info = trtorch._C._get_build_info()
+    build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
+    return build_info
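A hedged end-to-end sketch of the API this file adds; the module is illustrative, and a working CUDA/TensorRT build of trtorch is assumed:

import torch
import trtorch

# Tiny illustrative module; any scripted module works.
class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Net().eval())
extra_info = {"input_shapes": [(1, 3, 224, 224)]}

if trtorch.check_method_op_support(scripted, "forward"):    # prints unsupported ops if any
    trt_mod = trtorch.compile(scripted, extra_info)          # forward now runs via TensorRT
    engine = trtorch.convert_method_to_trt_engine(scripted, "forward", extra_info)
    with open("forward.engine", "wb") as f:                  # serialized engine bytes
        f.write(engine)

trtorch.dump_build_info()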

py/trtorch/compiler.py renamed to py/trtorch/_extra_info.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import List, Dict, Any
22
import torch
3-
import tensorrt as trt
43
import trtorch._C
5-
from trtorch import types
6-
from .version import __version__
4+
from trtorch import _types
75

86
def _supported_input_size_type(input_size: Any) -> bool:
97
if isinstance(input_size, torch.Size):
@@ -56,31 +54,31 @@ def _parse_input_ranges(input_sizes: List) -> List:
5654

5755
return parsed_input_sizes
5856

59-
def _parse_op_precision(precision: Any) -> types.dtype:
57+
def _parse_op_precision(precision: Any) -> _types.dtype:
6058
if isinstance(precision, torch.dtype):
6159
if precision == torch.int8:
62-
return types.dtype.int8
60+
return _types.dtype.int8
6361
elif precision == torch.half:
64-
return types.dtype.half
62+
return _types.dtype.half
6563
elif precision == torch.float:
66-
return types.dtype.float
64+
return _types.dtype.float
6765
else:
6866
raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))
6967

70-
elif isinstance(precision, types.DataTypes):
68+
elif isinstance(precision, _types.DataTypes):
7169
return precision
7270

7371
else:
7472
raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))
7573

76-
def _parse_device_type(device: Any) -> types.DeviceType:
74+
def _parse_device_type(device: Any) -> _types.DeviceType:
7775
if isinstance(device, torch.device):
7876
if torch.device.type == 'cuda':
79-
return types.DeviceType.gpu
77+
return _types.DeviceType.gpu
8078
else:
8179
raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))
8280

83-
elif isinstance(device, types.DeviceType):
81+
elif isinstance(device, _types.DeviceType):
8482
return device
8583

8684
else:
@@ -120,7 +118,6 @@ def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
120118
assert isinstance(extra_info["capability"], type.EngineCapability)
121119
info.capability = extra_info["capability"]
122120

123-
124121
if "num_min_timing_iters" in extra_info:
125122
assert type(extra_info["num_min_timing_iters"]) is int
126123
info.num_min_timing_iters = extra_info["num_min_timing_iters"]
@@ -137,22 +134,4 @@ def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
137134
assert type(extra_info["max_batch_size"]) is int
138135
info.max_batch_size = extra_info["max_batch_size"]
139136

140-
return info
141-
142-
def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
143-
return module
144-
145-
def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
146-
return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
147-
148-
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
149-
return trtorch._C._check_method_op_support(module._c, method_name)
150-
151-
def dump_build_info():
152-
print(get_build_info())
153-
154-
def get_build_info() -> str:
155-
build_info = trtorch._C._get_build_info()
156-
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
157-
return build_info
158-
137+
return info
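To illustrate the mapping these helpers perform, here is a hedged sketch that exercises the precision parser directly; these are private names, so this is for illustration only and assumes an importable trtorch build:

import torch
from trtorch import _types
from trtorch._extra_info import _parse_op_precision

# torch dtypes map onto the internal trtorch enum.
assert _parse_op_precision(torch.half) == _types.dtype.half
assert _parse_op_precision(torch.int8) == _types.dtype.int8

# Anything that is neither a torch.dtype nor a trtorch dtype is rejected.
try:
    _parse_op_precision("fp16")
except TypeError as e:
    print(e)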
File renamed without changes.
File renamed without changes.
