
Commit cdab9ec

Merge pull request #349 from NVIDIA/bowa_development

Add an example to compile converter into .so and run it in Python

2 parents 91ce3e3 + 58fbf12

File tree

5 files changed: +283 −0 lines changed

examples/custom_converters/README.md

Lines changed: 158 additions & 0 deletions
# Create a new op in C++, compile it into a .so library, and load it in Python

Some operators in the PyTorch library are not supported in TRTorch. To support
these ops, users can register converters for the missing operators. For example,
if we try to compile a graph with a build of TRTorch that doesn't support the
[ELU](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) operation,
we will get the following error:

> Unable to convert node: %result.2 : Tensor = aten::elu(%x.1, %2, %3, %3) # /home/bowa/.local/lib/python3.6/site-packages/torch/nn/functional.py:1227:17 (conversion.AddLayer)
> Schema: aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
> Converter for aten::elu requested, but no such converter was found.
> If you need a converter for this operator, you can try implementing one yourself
> or request a converter: https://www.github.com/NVIDIA/TRTorch/issues

Note that the ELU converter is now supported in our library. If you want to reproduce
the error above and run the example in this document, you can either:
1. Get the source code, go to the root directory, and run (full sequence sketched after this list): <br />
   `git apply ./examples/custom_converters/elu_converter/disable_core_elu.patch`
2. Use a pre-downloaded release of TRTorch that does not support the ELU operator
   by default (TRTorch <= v0.1.0).
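
For option 1, the full sequence would look roughly like this; the rebuild step is a sketch that assumes the standard bazel-based source build described in the TRTorch README:
```shell
# From the TRTorch source root: remove the in-tree ELU converter,
# then rebuild so compilation actually hits the missing-converter path.
git apply ./examples/custom_converters/elu_converter/disable_core_elu.patch
bazel build //:libtrtorch -c opt
```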
## Writing the Converter in C++
We can register a converter for this operator in our application. You can find more
information on all the details of writing converters in the contributors' documentation
([Writing Converters](https://nvidia.github.io/TRTorch/contributors/writing_converters.html)).
Once we are clear about these rules and writing patterns, we can create a separate new C++ source file:

```c++
#include <iostream>

#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

namespace my_custom_converters {

auto actelu = trtorch::core::conversion::converters::RegisterNodeConversionPatterns().pattern(
    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
     [](trtorch::core::conversion::ConversionCtx* ctx,
        const torch::jit::Node* n,
        trtorch::core::conversion::converters::args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
       auto alpha = args[1].unwrapToDouble();

       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
       if (!new_layer) {
         std::cerr << "Unable to create layer for aten::elu" << std::endl;
         return false;  // abort the conversion instead of dereferencing a null layer
       }

       new_layer->setAlpha(alpha);
       new_layer->setName(trtorch::core::util::node_info(n).c_str());
       ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

       return true;
     }});

} // namespace my_custom_converters
```
Because the registration happens in the constructor of a global object, linking this file into an application, or loading it as a shared library, is enough to register the converter with TRTorch; no explicit call is required.

## Generate the `.so` library
To use this converter in Python, it is recommended to use PyTorch's
[C++/CUDA Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions).
We give an example here of how to wrap the converter into a `.so`
library so that you can load it and use it in a Python application.
```python
from setuptools import setup
from torch.utils import cpp_extension


# library_dirs should point to libtrtorch.so, include_dirs should point to the directory that contains the headers
# 1) Download the latest package from https://github.com/NVIDIA/TRTorch/releases/
# 2) Extract the downloaded package to get the "trtorch" directory
# 3) Set trtorch_path to that directory
trtorch_path = <PATH TO TRTORCH>


ext_modules = [
    cpp_extension.CUDAExtension('elu_converter', ['./csrc/elu_converter.cpp'],
                                library_dirs=[(trtorch_path + "/lib/")],
                                libraries=["trtorch"],
                                include_dirs=[trtorch_path + "/include/trtorch/"])
]

setup(
    name='elu_converter',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
```
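Before building, it helps to verify that `trtorch_path` points at a directory with the layout the extension expects. A sketch, inferred from the `lib/` and `include/trtorch/` paths used in the script above:
```shell
# Layout of the extracted release package (illustrative):
# <PATH TO TRTORCH>
# ├── lib/               # contains libtrtorch.so  -> library_dirs
# └── include/
#     └── trtorch/       # public headers          -> include_dirs
ls <PATH TO TRTORCH>/lib/ | grep libtrtorch
```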
Make sure to include the path to the header files in `include_dirs` and the path
to the dependent libraries in `library_dirs`. Generally speaking, you should download
the latest package from [here](https://github.com/NVIDIA/TRTorch/releases), extract
the files, and set `trtorch_path` to that directory. You can also pass other compilation
flags to cpp_extension if you need them. Then, run the script above:
```shell
python3 setup.py install --user
```
You should see output similar to the contents shown [here](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions) after running
`python setup.py install`, and you should find a couple of new folders generated
by the command. In the `build` folder, you can find the generated `.so` library,
which can be loaded in our Python application.
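The exact folder and file names depend on your platform and Python version; for the Python 3.6 / Linux build used in this example, the output looks like this (the same `.so` path is hard-coded in the script below):
```shell
ls ./elu_converter/build/
# lib.linux-x86_64-3.6/  temp.linux-x86_64-3.6/
ls ./elu_converter/build/lib.linux-x86_64-3.6/
# elu_converter.cpython-36m-x86_64-linux-gnu.so
```
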
## Load the `.so` in a Python Application
With the newly generated library, TRTorch now supports the newly developed converter.
We use `torch.ops.load_library` to load the `.so`. For example, we can load the ELU
converter and use it in our application:
```python
import torch
import trtorch

# After "python3 setup.py install", you should find this .so file under the generated "build" directory
torch.ops.load_library('./elu_converter/build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')


class Elu(torch.nn.Module):

    def __init__(self):
        super(Elu, self).__init__()
        self.elu = torch.nn.ELU()

    def forward(self, x):
        return self.elu(x)


def cal_max_diff(pytorch_out, trtorch_out):
    diff = torch.sub(pytorch_out, trtorch_out)
    abs_diff = torch.abs(diff)
    max_diff = torch.max(abs_diff)
    print("Maximum difference between TRTorch and PyTorch: \n", max_diff)


def main():
    model = Elu().eval()  # ELU has no parameters, so the module itself need not be moved to the GPU

    scripted_model = torch.jit.script(model)
    compile_settings = {
        "input_shapes": [{
            "min": [1024, 1, 32, 32],
            "opt": [1024, 1, 33, 33],
            "max": [1024, 1, 34, 34],
        }],
        "op_precision": torch.half  # Run with FP16
    }
    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
    input_data = torch.randn((1024, 1, 32, 32))
    input_data = input_data.half().to("cuda")
    pytorch_out = model.forward(input_data)

    trtorch_out = trt_ts_module(input_data)
    print('PyTorch output: \n', pytorch_out[0, :, :, 0])
    print('TRTorch output: \n', trtorch_out[0, :, :, 0])
    cal_max_diff(pytorch_out, trtorch_out)


if __name__ == "__main__":
    main()
```
Run this script to get the outputs from PyTorch and TRTorch and the maximum difference between them.
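
If you prefer a pass/fail check to eyeballing the printed tensors, here is a minimal sketch; `check_close` and its tolerances are our own additions (loose values chosen for FP16 execution), not part of the example above:
```python
import torch

def check_close(pytorch_out, trtorch_out, atol=1e-2, rtol=1e-2):
    # FP16 execution accumulates more rounding error than FP32,
    # so these tolerances are deliberately loose; tune as needed.
    if torch.allclose(pytorch_out, trtorch_out, atol=atol, rtol=rtol):
        print("Outputs match within tolerance")
    else:
        max_diff = torch.max(torch.abs(pytorch_out - trtorch_out))
        print("Outputs diverge; maximum absolute difference:", max_diff.item())
```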
examples/custom_converters/elu_converter/disable_core_elu.patch

Lines changed: 26 additions & 0 deletions

```diff
diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp
index 64edeed..5279413 100644
--- a/core/conversion/converters/impl/activation.cpp
+++ b/core/conversion/converters/impl/activation.cpp
@@ -152,21 +152,6 @@ auto acthardtanh TRTORCH_UNUSED =
                    out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                    LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                    return true;
-                  }})
-      .pattern({"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
-                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                  auto in = args[0].ITensorOrFreeze(ctx);
-                  auto alpha = args[1].unwrapToDouble();
-
-                  auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
-                  TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
-                  new_layer->setAlpha(alpha);
-
-                  new_layer->setName(trtorch::core::util::node_info(n).c_str());
-
-                  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
-                  LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
-                  return true;
                }});

 } // namespace
```
