
Commit 211e131

feature/tensorrt engine op (#11001)
1 parent 4944920 commit 211e131

5 files changed: +213 −4 lines

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 25 additions & 1 deletion
@@ -131,6 +131,20 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
   return buffer(name).buffer;
 }
 
+void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
+                                    size_t max_size) {
+  // determine data size
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_GE(max_size, it->second);
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+                                    cudaMemcpyDeviceToDevice, *stream_),
+                    0);
+}
+
 void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
                                     size_t max_size) {
   // determine data size
@@ -152,7 +166,7 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
   return buffers_[slot_offset];
 }
 
-void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
+void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
                                      size_t size) {
   auto& buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer);
@@ -162,6 +176,16 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
                                       cudaMemcpyHostToDevice, *stream_));
 }
 
+void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
+                                     size_t size) {
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
+  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
+  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
+                                       cudaMemcpyDeviceToDevice, *stream_));
+}
+
 void TensorRTEngine::SetITensor(const std::string& name,
                                 nvinfer1::ITensor* tensor) {
   PADDLE_ENFORCE(tensor != nullptr);
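Note on the two new copy paths: both enqueue cudaMemcpyAsync on the engine's stream, so the data in dst is only valid once that stream has been synchronized. A minimal sketch of consuming the device-to-device output path; the function name and the pre-allocated device buffer are illustrative assumptions, not part of this commit:

#include <cuda_runtime.h>
#include <string>
#include "paddle/fluid/inference/tensorrt/engine.h"

// Sketch only: assumes `engine` has already been built, frozen, and executed,
// and `dst` is a device allocation of at least `max_size` bytes.
void FetchOutputOnDevice(paddle::inference::tensorrt::TensorRTEngine* engine,
                         const std::string& name, void* dst, size_t max_size) {
  engine->GetOutputInGPU(name, dst, max_size);  // enqueues an async D2D copy
  cudaStreamSynchronize(*engine->stream());     // block until the copy lands
}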

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 5 additions & 3 deletions
@@ -92,13 +92,15 @@ class TensorRTEngine : public EngineBase {
   cudaStream_t* stream() { return stream_; }
 
   // Fill an input from CPU memory with name and size.
-  void SetInputFromCPU(const std::string& name, void* data, size_t size);
+  void SetInputFromCPU(const std::string& name, const void* data, size_t size);
   // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
   // accessed directly. Fill an input from GPU memory with name and size.
-  void SetInputFromGPU(const std::string& name, void* data, size_t size);
+  void SetInputFromGPU(const std::string& name, const void* data, size_t size);
   // Get an output called name, the output of tensorrt is in GPU, so this method
-  // will just return the output's GPU memory address.
+  // Return the output's GPU memory address without copy.
   void* GetOutputInGPU(const std::string& name);
+  // Copy data into dst inside the GPU device.
+  void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
   // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
   void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
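Taken together, these declarations give the engine a feed/execute/fetch surface. A hedged end-to-end sketch using only the calls visible in this diff; engine construction, input/output declaration, and the tensor names "x"/"y" are assumed to be set up elsewhere:

#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"

// Sketch only: `engine` is assumed built and frozen with an input named "x"
// and an output named "y"; `y` must already be sized for the output.
void RunOnce(paddle::inference::tensorrt::TensorRTEngine* engine,
             const std::vector<float>& x, std::vector<float>* y) {
  engine->SetInputFromCPU("x", x.data(), x.size() * sizeof(float));
  engine->Execute(1);  // batch size 1
  engine->GetOutputInCPU("y", y->data(), y->size() * sizeof(float));
}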

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -225,6 +225,9 @@ op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(softmax_op DEPS softmax)
 op_library(sequence_softmax_op DEPS softmax)
+if (WITH_GPU AND TENSORRT_FOUND)
+  op_library(tensorrt_engine_op DEPS tensorrt_engine)
+endif()
 op_library(sum_op DEPS selected_rows_functor)
 op_library(sgd_op DEPS selected_rows_functor)
 op_library(print_op DEPS lod_tensor)
paddle/fluid/operators/tensorrt_engine_op.cc

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#ifdef PADDLE_WITH_CUDA

#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
    const framework::ExecutionContext &context) const {
  // Get the ProgramDesc and pass to convert.
  const auto &block = context.Attr<framework::proto::BlockDesc>("subgraph");
  max_batch_ = context.Attr<int>("max_batch");
  auto max_workspace = context.Attr<int>("max_workspace");
  engine_.reset(new inference::tensorrt::TensorRTEngine(
      max_batch_, max_workspace, nullptr));
  inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
      block, engine_.get());
  engine_->FreezeNetwork();
}

class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Xs", "A list of inputs.").AsDuplicable();
    AddOutput("Ys", "A list of outputs").AsDuplicable();
    AddAttr<std::string>("subgraph", "the subgraph");
    AddComment("TensorRT engine operator.");
  }
};

class TensorRTEngineInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
                  ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);

REGISTER_OP_CPU_KERNEL(
    tensorrt_engine,
    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, float>,
    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, double>,
    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>,
    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>);

#endif  // PADDLE_WITH_CUDA
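For orientation, a hedged sketch of how this operator might be instantiated through an OpDesc. The slot names "Xs"/"Ys" and the "subgraph" attribute follow TensorRTEngineOpMaker above; the variable names, the serialized-string form of the block, and the extra max_batch/max_workspace attributes (read by Prepare() but not declared by the Maker) are assumptions, not part of this commit:

#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

// Sketch only: `program` is assumed to hold the subgraph in block 0, and the
// variables "x0"/"y0" are expected to exist in the scope the op runs against.
std::unique_ptr<paddle::framework::OperatorBase> MakeTRTEngineOp(
    paddle::framework::ProgramDesc* program) {
  paddle::framework::OpDesc desc;
  desc.SetType("tensorrt_engine");
  desc.SetInput("Xs", {"x0"});
  desc.SetOutput("Ys", {"y0"});
  desc.SetAttr("subgraph", program->Proto()->blocks(0).SerializeAsString());
  desc.SetAttr("max_batch", 16);
  desc.SetAttr("max_workspace", 1 << 20);
  return paddle::framework::OpRegistry::CreateOp(desc);
}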
paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_CUDA

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace operators {

class TensorRTEngineOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::OpKernelType kt = framework::OpKernelType(
        framework::ToDataType(
            ctx.Input<framework::LoDTensor>("pre_ids")->type()),
        platform::CPUPlace());
    return kt;
  }
};

template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    if (!engine_) {
      Prepare(context);
    }
    auto input_names = context.op().Inputs("Xs");
    PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
    // Try to determine a batch_size
    auto* tensor0 = context.Input<framework::LoDTensor>(input_names.front());
    PADDLE_ENFORCE_NOT_NULL(tensor0);
    int batch_size = tensor0->dims()[0];
    PADDLE_ENFORCE_LE(batch_size, max_batch_);

    // Convert input tensor from fluid to engine.
    for (const auto& x : context.Inputs("Xs")) {
      // convert input and copy to TRT engine's buffer
      auto* v = context.scope().FindVar(x);
      PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x);
      auto& t = v->Get<framework::LoDTensor>();
      if (platform::is_cpu_place(t.place())) {
        engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
                                 t.memory_size());
      } else {
        engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
                                 t.memory_size());
      }
    }
    // Execute the engine.
    PADDLE_ENFORCE_GT(batch_size, 0);
    engine_->Execute(batch_size);
    // Convert output tensor from engine to fluid
    for (const auto& y : context.Outputs("Ys")) {
      // convert output and copy to fluid.
      nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
      auto dims = trt_t->getDimensions();
      // Use the output ITensor's dims to reshape the Fluid Tensor.
      std::vector<int> ddim(dims.d, dims.d + dims.nbDims);

      auto* fluid_v = context.scope().FindVar(y);
      PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
      auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
      fluid_t->Resize(framework::make_ddim(ddim));
      auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
      if (platform::is_cpu_place(fluid_t->place())) {
        engine_->GetOutputInCPU(
            y, fluid_t->mutable_data<float>(platform::CPUPlace()), size);
      } else {
        engine_->GetOutputInGPU(
            y, fluid_t->mutable_data<float>(platform::CUDAPlace()), size);
      }
    }
  }

 protected:
  // Build the engine.
  void Prepare(const framework::ExecutionContext& context) const;

 private:
  mutable std::unique_ptr<inference::tensorrt::TensorRTEngine> engine_;
  mutable int max_batch_{0};
};

}  // namespace operators
}  // namespace paddle

#endif  // PADDLE_WITH_CUDA
