Skip to content

Commit 3a96b9c

Browse files
committed
Add sample rt app
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4a9c595 commit 3a96b9c

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

examples/sample_rt_app/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
# trtorch is the dowloaded tar file
4+
# It has include, lib, bin directory and LICENSE file
5+
cc_library(
6+
name = "trtorch_runtime",
7+
srcs = ["trtorch/lib/libtrtorchrt.so", "trtorch/lib/libtrtorch_plugins.so"],
8+
hdrs = ["trtorch/include/trtorch/core/runtime/runtime.h"], # "trtorch/include/trtorch/trtorch.h", "trtorch/include/trtorch/macros.h"],
9+
includes = ["trtorch/include/trtorch"],
10+
#include_prefix="trtorch/include",
11+
)
12+
13+
cc_binary(
14+
name = "samplertapp",
15+
srcs = [
16+
"main.cpp"
17+
],
18+
includes = ["trtorch/include", "trtorch/include/trtorch"],
19+
deps = [
20+
":trtorch_runtime",
21+
"@libtorch//:libtorch",
22+
"@libtorch//:caffe2",
23+
"@tensorrt//:nvinfer",
24+
],
25+
)

examples/sample_rt_app/main.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// #include "examples/sample_rt_app/trtorch/include/trtorch/core/runtime/runtime.h"
2+
#include "trtorch/core/runtime/runtime.h"
3+
#include <iostream>
4+
#include <fstream>
5+
#include <memory>
6+
#include <sstream>
7+
#include <vector>
8+
#include "torch/script.h"
9+
#include "trtorch/include/trtorch/trtorch.h"
10+
11+
// Load the TRT engine from engine_path
12+
std::vector<char> loadEngine(std::string engine_path){
13+
std::ifstream engineFile(engine_path, std::ios::binary);
14+
if (!engineFile)
15+
{
16+
std::cerr << "Error opening TensorRT Engine file at : " << engine_path << std::endl;
17+
}
18+
19+
engineFile.seekg(0, engineFile.end);
20+
long int fsize = engineFile.tellg();
21+
engineFile.seekg(0, engineFile.beg);
22+
23+
std::vector<char> engineData(fsize);
24+
engineFile.read(engineData.data(), fsize);
25+
if (!engineFile)
26+
{
27+
std::cerr << "Error loading engine from: " << engine_path << std::endl;
28+
}
29+
30+
return engineData;
31+
}
32+
33+
int main(int argc, const char* argv[]) {
34+
if (argc < 2) {
35+
std::cerr
36+
<< "usage: samplertapp <path-to-pre-built-trt-engine>\n";
37+
return -1;
38+
}
39+
40+
std::string engine_path = argv[1];
41+
auto engineData = loadEngine(engine_path);
42+
43+
std::cout << "Running TRT engine" << std::endl;
44+
auto engine_ptr = c10::make_intrusive<TRTEngine>("test_engine", engineData.data());
45+
auto inputs = at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA});
46+
auto outputs = trtorch::core::runtime::execute_engine(inputs, engine_ptr);
47+
std::cout << "TRT engine execution completed. " << std::endl;
48+
}

examples/sample_rt_app/network.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
# Create a sample network with a conv and gelu node.
5+
# Gelu layer in TRTorch is converted to CustomGeluPluginDynamic from TensorRT plugin registry.
6+
class ConvGelu(torch.nn.Module):
7+
def __init__(self):
8+
super(ConvGelu, self).__init__()
9+
self.conv = nn.Conv2d(3, 32, 3, 1)
10+
self.gelu = nn.GELU()
11+
12+
def forward(self, x):
13+
x = self.conv(x)
14+
x = self.gelu(x)
15+
return x
16+
17+
def main():
18+
19+
model = ConvGelu().eval().cuda()
20+
scripted_model = torch.jit.script(model)
21+
# Save the torchscript model
22+
torch.jit.save(scripted_model, 'conv_gelu.jit')
23+
print("Generated conv_gelu.jit model.")
24+
25+
if __name__ == "__main__":
26+
main()

0 commit comments

Comments
 (0)