Commit 1866597

add tensorrt build support(#9891)
1 parent 0032b4a commit 1866597

File tree

12 files changed: +308 -0 lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -39,6 +39,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
 option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
 option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
 option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
+option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF)
 option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
 option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
 option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
@@ -181,6 +182,11 @@ if(WITH_GPU)
   include(cuda)
 endif(WITH_GPU)
 
+# TensorRT depends on GPU.
+if (NOT WITH_GPU)
+  set(WITH_TENSORRT OFF)
+endif()
+
 if(WITH_AMD_GPU)
   find_package(HIP)
   include(hip)

Dockerfile

Lines changed: 7 additions & 0 deletions

@@ -45,6 +45,13 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
 # install glide
 RUN curl -s -q https://glide.sh/get | sh
 
+# Install TensorRT
+# The unnecessary files have been removed to keep the library small.
+RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
+    tar -xz -C /usr/local && \
+    cp -rf /usr/local/TensorRT/include /usr && \
+    cp -rf /usr/local/TensorRT/lib /usr
+
 # git credential to skip password typing
 RUN git config --global credential.helper store

paddle/fluid/inference/CMakeLists.txt

Lines changed: 3 additions & 0 deletions

@@ -21,4 +21,7 @@ endif()
 
 if(WITH_TESTING)
   add_subdirectory(tests/book)
+  if (WITH_TENSORRT)
+    add_subdirectory(tensorrt)
+  endif()
 endif()
paddle/fluid/inference/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
paddle/fluid/inference/tensorrt/test_tensorrt.cc

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "NvInfer.h"
+#include "cuda.h"
+#include "cuda_runtime_api.h"
+#include "paddle/fluid/platform/dynload/tensorrt.h"
+
+namespace dy = paddle::platform::dynload;
+
+class Logger : public nvinfer1::ILogger {
+ public:
+  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+    switch (severity) {
+      case Severity::kINFO:
+        LOG(INFO) << msg;
+        break;
+      case Severity::kWARNING:
+        LOG(WARNING) << msg;
+        break;
+      case Severity::kINTERNAL_ERROR:
+      case Severity::kERROR:
+        LOG(ERROR) << msg;
+        break;
+      default:
+        break;
+    }
+  }
+};
+
+class ScopedWeights {
+ public:
+  ScopedWeights(float value) : value_(value) {
+    w.type = nvinfer1::DataType::kFLOAT;
+    w.values = &value_;
+    w.count = 1;
+  }
+  const nvinfer1::Weights& get() { return w; }
+
+ private:
+  float value_;
+  nvinfer1::Weights w;
+};
+
+// The following two APIs are implemented in TensorRT's header file and thus
+// cannot be loaded from the dynamic library, so we implement them ourselves
+// and call the corresponding entry points in the dynamic library directly.
+nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+  return static_cast<nvinfer1::IBuilder*>(
+      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+  return static_cast<nvinfer1::IRuntime*>(
+      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+
+const char* kInputTensor = "input";
+const char* kOutputTensor = "output";
+
+// Creates a network to compute y = 2x + 3
+nvinfer1::IHostMemory* CreateNetwork() {
+  Logger logger;
+  // Create the engine.
+  nvinfer1::IBuilder* builder = createInferBuilder(logger);
+  ScopedWeights weights(2.);
+  ScopedWeights bias(3.);
+
+  nvinfer1::INetworkDefinition* network = builder->createNetwork();
+  // Add the input
+  auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
+                                 nvinfer1::DimsCHW{1, 1, 1});
+  EXPECT_NE(input, nullptr);
+  // Add the hidden layer.
+  auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
+  EXPECT_NE(layer, nullptr);
+  // Mark the output.
+  auto output = layer->getOutput(0);
+  output->setName(kOutputTensor);
+  network->markOutput(*output);
+  // Build the engine.
+  builder->setMaxBatchSize(1);
+  builder->setMaxWorkspaceSize(1 << 10);
+  auto engine = builder->buildCudaEngine(*network);
+  EXPECT_NE(engine, nullptr);
+  // Serialize the engine to create a model, then close.
+  nvinfer1::IHostMemory* model = engine->serialize();
+  network->destroy();
+  engine->destroy();
+  builder->destroy();
+  return model;
+}
+
+void Execute(nvinfer1::IExecutionContext& context, const float* input,
+             float* output) {
+  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  // Two bindings: input and output.
+  ASSERT_EQ(engine.getNbBindings(), 2);
+  const int input_index = engine.getBindingIndex(kInputTensor);
+  const int output_index = engine.getBindingIndex(kOutputTensor);
+  // Create GPU buffers and a stream
+  void* buffers[2];
+  ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
+  ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
+  cudaStream_t stream;
+  ASSERT_EQ(0, cudaStreamCreate(&stream));
+  // Copy the input to the GPU, execute the network, and copy the output back.
+  ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
+                               cudaMemcpyHostToDevice, stream));
+  context.enqueue(1, buffers, stream, nullptr);
+  ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
+                               cudaMemcpyDeviceToHost, stream));
+  cudaStreamSynchronize(stream);
+
+  // Release the stream and the buffers
+  cudaStreamDestroy(stream);
+  ASSERT_EQ(0, cudaFree(buffers[input_index]));
+  ASSERT_EQ(0, cudaFree(buffers[output_index]));
+}
+
+TEST(TensorrtTest, BasicFunction) {
+  // Create the network serialized model.
+  nvinfer1::IHostMemory* model = CreateNetwork();
+
+  // Use the model to create an engine and an execution context.
+  Logger logger;
+  nvinfer1::IRuntime* runtime = createInferRuntime(logger);
+  nvinfer1::ICudaEngine* engine =
+      runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
+  model->destroy();
+  nvinfer1::IExecutionContext* context = engine->createExecutionContext();
+
+  // Execute the network.
+  float input = 1234;
+  float output;
+  Execute(*context, &input, &output);
+  EXPECT_EQ(output, input * 2 + 3);
+
+  // Destroy the engine.
+  context->destroy();
+  engine->destroy();
+  runtime->destroy();
+}
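Why the test must define its own createInferBuilder/createInferRuntime: as the comment in the file notes, the public factory functions live in TensorRT's header, not in the shared library's symbol table, so dlsym can only find the *_INTERNAL entry points. A paraphrased sketch of the header-side wrapper is below; it is an approximation from the NvInfer.h pattern, not the verbatim vendor source.

// Approximate shape of the factory function in NvInfer.h (paraphrased).
// Because it is inline, no "createInferBuilder" symbol is exported from
// libnvinfer.so; only createInferBuilder_INTERNAL is, which is why the
// dynload wrapper targets the _INTERNAL name.
namespace nvinfer1 {
inline IBuilder* createInferBuilder(ILogger& logger) {
  return static_cast<IBuilder*>(
      createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
}  // namespace nvinfer1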

paddle/fluid/platform/dynload/CMakeLists.txt

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
 
 list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc)
+if (WITH_TENSORRT)
+  list(APPEND CUDA_SRCS tensorrt.cc)
+endif()
+
+
 configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h)
 if (CUPTI_FOUND)
   list(APPEND CUDA_SRCS cupti.cc)

paddle/fluid/platform/dynload/dynamic_loader.cc

Lines changed: 12 additions & 0 deletions

@@ -45,6 +45,10 @@ DEFINE_string(nccl_dir, "",
 
 DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so.");
 
+DEFINE_string(
+    tensorrt_dir, "",
+    "Specify path for loading tensorrt library, such as libnvinfer.so.");
+
 namespace paddle {
 namespace platform {
 namespace dynload {
@@ -194,6 +198,14 @@ void* GetNCCLDsoHandle() {
 #endif
 }
 
+void* GetTensorRtDsoHandle() {
+#if defined(__APPLE__) || defined(__OSX__)
+  return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib");
+#else
+  return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so");
+#endif
+}
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
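GetDsoHandleFromSearchPath is defined earlier in dynamic_loader.cc and is not part of this diff. As a rough mental model of its behavior — an assumption about the unshown helper, not its actual body — it tries the user-supplied directory first and then falls back to the dynamic linker's default search paths:

#include <dlfcn.h>
#include <string>

// Hedged sketch (hypothetical name) of what GetDsoHandleFromSearchPath
// plausibly does; the real helper also handles warning messages and
// additional fallback locations.
static void* GetDsoHandleFromSearchPathSketch(const std::string& search_root,
                                              const std::string& dso_name) {
  // Prefer the user-specified directory, e.g. --tensorrt_dir=/usr/local/TensorRT/lib.
  if (!search_root.empty()) {
    std::string full_path = search_root + "/" + dso_name;
    if (void* handle = dlopen(full_path.c_str(), RTLD_LAZY)) return handle;
  }
  // Otherwise let the dynamic linker search LD_LIBRARY_PATH and system paths.
  return dlopen(dso_name.c_str(), RTLD_LAZY);
}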

paddle/fluid/platform/dynload/dynamic_loader.h

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ void* GetCurandDsoHandle();
 void* GetWarpCTCDsoHandle();
 void* GetLapackDsoHandle();
 void* GetNCCLDsoHandle();
+void* GetTensorRtDsoHandle();
 
 }  // namespace dynload
 }  // namespace platform
paddle/fluid/platform/dynload/tensorrt.cc

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/tensorrt.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+std::once_flag tensorrt_dso_flag;
+void *tensorrt_dso_handle;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
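For readers tracing the macro plumbing: TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP) above expands to one global functor object per wrapped routine. Written out by hand, the expansion is simply:

// Hand-written expansion of TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP).
// Each DynLoad__* struct is declared in tensorrt.h (next file); these globals
// are what call sites like dy::createInferBuilder_INTERNAL(...) actually invoke.
DynLoad__createInferBuilder_INTERNAL createInferBuilder_INTERNAL;
DynLoad__createInferRuntime_INTERNAL createInferRuntime_INTERNAL;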
paddle/fluid/platform/dynload/tensorrt.h

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include <NvInfer.h>
+#include <dlfcn.h>
+
+#include <mutex>  // NOLINT
+
+#include "paddle/fluid/platform/dynload/dynamic_loader.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+extern std::once_flag tensorrt_dso_flag;
+extern void* tensorrt_dso_handle;
+
+#ifdef PADDLE_USE_DSO
+
+#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name)                      \
+  struct DynLoad__##__name {                                            \
+    template <typename... Args>                                         \
+    auto operator()(Args... args) -> decltype(__name(args...)) {        \
+      using tensorrt_func = decltype(__name(args...)) (*)(Args...);     \
+      std::call_once(tensorrt_dso_flag, []() {                          \
+        tensorrt_dso_handle =                                           \
+            paddle::platform::dynload::GetTensorRtDsoHandle();          \
+        PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
+      });                                                               \
+      void* p_##__name = dlsym(tensorrt_dso_handle, #__name);           \
+      PADDLE_ENFORCE(p_##__name, "load %s failed", #__name);            \
+      return reinterpret_cast<tensorrt_func>(p_##__name)(args...);      \
+    }                                                                   \
+  };                                                                    \
+  extern DynLoad__##__name __name
+
+#else
+#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
+  struct DynLoad__##__name {                       \
+    template <typename... Args>                    \
+    tensorrtResult_t operator()(Args... args) {    \
+      return __name(args...);                      \
+    }                                              \
+  };                                               \
+  extern DynLoad__##__name __name
+#endif
+
+#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
+  __macro(createInferBuilder_INTERNAL);     \
+  __macro(createInferRuntime_INTERNAL);
+
+TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
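To make the PADDLE_USE_DSO branch concrete, here is roughly what DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(createInferBuilder_INTERNAL) expands to, with the backslashes removed and #__name spliced in: a functor that lazily opens libnvinfer exactly once, resolves the symbol by name, and forwards its arguments.

// Hand expansion of DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(createInferBuilder_INTERNAL).
struct DynLoad__createInferBuilder_INTERNAL {
  template <typename... Args>
  auto operator()(Args... args)
      -> decltype(createInferBuilder_INTERNAL(args...)) {
    using tensorrt_func = decltype(createInferBuilder_INTERNAL(args...)) (*)(Args...);
    // dlopen libnvinfer exactly once, even if several wrapped calls race here.
    std::call_once(tensorrt_dso_flag, []() {
      tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtDsoHandle();
      PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed");
    });
    // Look up the real symbol in the shared library and forward the call.
    void* p_createInferBuilder_INTERNAL =
        dlsym(tensorrt_dso_handle, "createInferBuilder_INTERNAL");
    PADDLE_ENFORCE(p_createInferBuilder_INTERNAL, "load %s failed",
                   "createInferBuilder_INTERNAL");
    return reinterpret_cast<tensorrt_func>(p_createInferBuilder_INTERNAL)(
        args...);
  }
};
extern DynLoad__createInferBuilder_INTERNAL createInferBuilder_INTERNAL;

A call such as dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION) in test_tensorrt.cc therefore never links against libnvinfer at build time; the library is located and loaded on first use. Note that the non-DSO branch returns tensorrtResult_t, a type not defined anywhere in this patch, so the header appears to compile only when PADDLE_USE_DSO is set.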
