Commit 2ae7cd1

Created ELU converter, compiled it to .so and run it in Python
Signed-off-by: Bo Wang <[email protected]>
1 parent 6442fce commit 2ae7cd1

File tree

4 files changed, +254 −0 lines changed


examples/README.md

Lines changed: 159 additions & 0 deletions
# Create a new op in C++, compile it to a .so library and load it in Python

Some operators in the PyTorch library are not supported in TRTorch. To support
these ops, users can register converters for the missing operators. For example,
if we try to compile a graph with a build of TRTorch that doesn't support the
[ELU](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) operation,
we will get the following error:

> Unable to convert node: %result.2 : Tensor = aten::elu(%x.1, %2, %3, %3) # /home/bowa/.local/lib/python3.6/site-packages/torch/nn/functional.py:1227:17 (conversion.AddLayer)
> Schema: aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
> Converter for aten::elu requested, but no such converter was found.
> If you need a converter for this operator, you can try implementing one yourself
> or request a converter: https://www.github.com/NVIDIA/TRTorch/issues

## Writing a Converter in C++

We can register a converter for this operator in our application. You can find more
information on all the details of writing converters in the contributors documentation
([Writing Converters](https://nvidia.github.io/TRTorch/contributors/writing_converters.html)).
Once we are familiar with these rules and writing patterns, we can create a separate
C++ source file:

```c++
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Fetch the input tensor and the alpha argument from the JIT node
       auto in = args[0].ITensorOrFreeze(ctx);
       auto alpha = args[1].unwrapToDouble();

       // Add a TensorRT ELU activation layer to the network being built
       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");

       new_layer->setAlpha(alpha);
       new_layer->setName(util::node_info(n).c_str());
       // Associate the node's output value with the layer's output tensor
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
```

## Generate a `.so` library

To use this converter in Python, it is recommended to use PyTorch's
[C++/CUDA Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions) mechanism.
Here is an example of how to wrap the converter into a `.so`
library so that you can load and use it in a Python application.
```python
import os
from setuptools import setup, Extension
from torch.utils import cpp_extension

dir_path = os.path.dirname(os.path.realpath(__file__))

ext_modules = [
    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
                                library_dirs=[(dir_path + "/../../bazel-bin/cpp/api/lib/")],
                                libraries=["trtorch"],
                                include_dirs=[dir_path + "/../../"])
]

setup(
    name='elu_converter',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
```
Make sure to include the path to the header files in `include_dirs` and the path
to the dependent libraries in `library_dirs`. You can also pass additional compilation
flags through `cpp_extension` if you need to, as sketched below.
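For instance, `extra_compile_args` can be passed to `cpp_extension.CUDAExtension`; the flags below are only an illustrative sketch and are not required to build this example:
```python
# Hypothetical variant of the extension defined above; the extra_compile_args
# values shown here are illustrative assumptions, not part of this example.
ext_modules = [
    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
                                library_dirs=[(dir_path + "/../../bazel-bin/cpp/api/lib/")],
                                libraries=["trtorch"],
                                include_dirs=[dir_path + "/../../"],
                                extra_compile_args={'cxx': ['-O2'],
                                                    'nvcc': ['-O2']})
]
```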
Then run the setup script:
```shell
python3 setup.py install --user
```
After running `python setup.py install`, you should see output similar to the contents shown
[here](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions).
The command generates a few new folders; in the `build` folder you can find the generated
`.so` library, which can be loaded in our Python application.
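As a quick check, you can list the build output directory (the path below assumes Python 3.6 on Linux, matching the library path loaded later in this example):
```shell
ls build/lib.linux-x86_64-3.6/
# elu_converter.cpython-36m-x86_64-linux-gnu.so
```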
## Load `.so` in Python Application

With the newly generated library, TRTorch now supports our new converter.
We use `torch.ops.load_library` to load the `.so`. For example, we can load the ELU
converter and use it in our application:
```python
import torch
import trtorch

# Load the compiled converter library so TRTorch can find the ELU converter
torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')

class Elu(torch.nn.Module):
    def __init__(self):
        super(Elu, self).__init__()
        self.elu = torch.nn.ELU()

    def forward(self, x):
        return self.elu(x)

def main():
    data = torch.randn((1, 1, 2, 2)).to("cuda")
    model = Elu().eval()  #.cuda()

    scripted_model = torch.jit.script(model)
    print(scripted_model.graph)
    compile_settings = {
        "input_shapes": [{
            "min": [1024, 1, 32, 32],
            "opt": [1024, 1, 33, 33],
            "max": [1024, 1, 34, 34],
        }],
        "op_precision": torch.half  # Run with FP16
    }
    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
    input_data = torch.randn((1024, 1, 32, 32))
    print(input_data[0, :, :, 0])
    input_data = input_data.half().to("cuda")
    result = trt_ts_module(input_data)
    print(result[0, :, :, 0])

if __name__ == "__main__":
    main()
```
Running this script prints the tensor before and after the ELU operator.

### Example Output
```bash
graph(%self : __torch__.Elu,
      %x.1 : Tensor):
  %2 : __torch__.torch.nn.modules.activation.ELU = prim::GetAttr[name="elu"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # elu_converter_test.py:13:15
  return (%4)

tensor([[ 1.3482,  1.9848, -1.0818, -1.3252,  0.2470,  0.7011,  0.3174, -1.8349,
          0.3024, -0.0453, -0.0681, -1.7377,  1.5909,  0.2549, -0.3029,  0.2583,
          0.0242,  2.0748, -0.5454,  0.7137,  1.6688,  0.7108, -0.8681,  0.2486,
         -1.3981,  1.0241,  1.2413,  0.2725,  1.4265,  0.9329,  0.4020, -2.6813]])
tensor([[ 1.3486,  1.9844, -0.6611, -0.7344,  0.2471,  0.7012,  0.3174, -0.8403,
          0.3025, -0.0443, -0.0659, -0.8242,  1.5908,  0.2549, -0.2615,  0.2583,
          0.0242,  2.0742, -0.4204,  0.7139,  1.6689,  0.7109, -0.5801,  0.2485,
         -0.7529,  1.0244,  1.2412,  0.2725,  1.4268,  0.9331,  0.4021, -0.9316]],
       device='cuda:0', dtype=torch.float16)
```
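As an optional sanity check (not part of the original script; `result` and `input_data` are the variables from `main()` above), the TensorRT output can be compared against PyTorch's reference ELU, allowing a loose tolerance for FP16 rounding:
```python
import torch.nn.functional as F

# Hypothetical verification step: compare the TRTorch result with PyTorch's own ELU.
reference = F.elu(input_data)
print(torch.allclose(result, reference, atol=1e-2))
```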
examples/elu_converter/elu_converter.cpp

Lines changed: 33 additions & 0 deletions
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
       auto alpha = args[1].unwrapToDouble();

       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");

       new_layer->setAlpha(alpha);
       new_layer->setName(util::node_info(n).c_str());
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
examples/elu_converter/elu_converter_test.py

Lines changed: 44 additions & 0 deletions
import torch
import trtorch

torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')


class Elu(torch.nn.Module):

    def __init__(self):
        super(Elu, self).__init__()
        self.elu = torch.nn.ELU()

    def forward(self, x):
        return self.elu(x)


def main():
    data = torch.randn((1, 1, 2, 2)).to("cuda")
    model = Elu().eval()  #.cuda()

    # traced_model = torch.jit.trace(model, [data])
    scripted_model = torch.jit.script(model)
    print(scripted_model.graph)
    # torch.jit.save(scripted_model, 'elu.jit')
    compile_settings = {
        "input_shapes": [{
            "min": [1024, 1, 32, 32],
            "opt": [1024, 1, 33, 33],
            "max": [1024, 1, 34, 34],
        }],
        "op_precision": torch.half  # Run with FP16
    }
    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
    input_data = torch.randn((1024, 1, 32, 32))
    print(input_data[0, :, :, 0])
    input_data = input_data.half().to("cuda")
    result = trt_ts_module(input_data)
    print(result[0, :, :, 0])
    # torch.jit.save(trt_ts_module, "trt_ts_module.ts")


if __name__ == "__main__":
    main()

examples/elu_converter/setup.py

Lines changed: 18 additions & 0 deletions
import os
from setuptools import setup, Extension
from torch.utils import cpp_extension

dir_path = os.path.dirname(os.path.realpath(__file__))

ext_modules = [
    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
                                library_dirs=[(dir_path + "/../../bazel-bin/cpp/api/lib/")],
                                libraries=["trtorch"],
                                include_dirs=[dir_path + "/../../"])
]

setup(
    name='elu_converter',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
