
Commit 6169d72

Merge pull request #12324 from NHZlX/enhance_for_tensorrt_infer

Enhance for tensorrt infer

Parents: 9f0d9df + 4d49e61

File tree: 11 files changed, 123 insertions(+), 67 deletions(-)

paddle/fluid/inference/tensorrt/convert/fc_op.cc

Lines changed: 6 additions & 8 deletions
@@ -32,11 +32,11 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
   for (int h = 0; h < shape.h(); ++h) {
     for (int w = 0; w < shape.w(); ++w) {
       odata[h * ostrides.h() + w * ostrides.w()] =
-          idata[h * ostrides.h() + w * ostrides.w()];
+          idata[h * istrides.h() + w * istrides.w()];
     }
   }
 }
-
+// indata c * k
 // Reorder the data layout from CK to KC.
 void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
                    TensorRTEngine::Weight* oweights) {
@@ -79,9 +79,8 @@ class FcOpConverter : public OpConverter {
 
     framework::LoDTensor tmp;
     tmp.Resize(Y_t->dims());
-    memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
-           Y_t->dims()[0] * Y_t->dims()[1]);
-
+    memcpy(tmp.mutable_data<float>(platform::CPUPlace()), weight_data,
+           Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
     TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                   static_cast<void*>(weight_data),
                                   Y_t->memory_size() / sizeof(float)};
@@ -93,7 +92,7 @@ class FcOpConverter : public OpConverter {
 
     // The data layout of TRT FC layer's weight is different from fluid's FC,
     // need to reorder the elements.
-    ReorderCKtoKC(tmp_weight, &weight);
+    ReorderCKtoKC(weight, &tmp_weight);
 
     // Currently, the framework can only handle one fluid op -> one TRT layer,
     // but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
@@ -103,7 +102,7 @@ class FcOpConverter : public OpConverter {
 
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
                                        *const_cast<nvinfer1::ITensor*>(X),
-                                       n_output, weight.get(), bias.get());
+                                       n_output, tmp_weight.get(), bias.get());
 
     auto output_name = op_desc.Output("Out").front();
     engine_->SetITensor(output_name, layer->getOutput(0));
@@ -118,4 +117,3 @@ class FcOpConverter : public OpConverter {
 }  // namespace paddle
 
 REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
-USE_OP(mul);
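Note: the Reorder2 fix above corrects an indexing bug (the read side used ostrides instead of istrides, so the copy never actually applied the input layout), and ReorderCKtoKC is now called with the fluid-layout weight as the source and tmp_weight as the destination, which is what the FullyConnected layer receives. The following standalone sketch shows the same CK -> KC transpose on a plain row-major buffer; the function name and types here are illustrative, not Paddle's API.

#include <vector>

// Transpose a C x K row-major weight matrix into K x C layout, the reorder
// the converter performs before handing the weight to TensorRT's FC layer.
std::vector<float> ReorderCKtoKCExample(const std::vector<float>& ck, int c, int k) {
  std::vector<float> kc(ck.size());
  for (int i = 0; i < c; ++i) {
    for (int j = 0; j < k; ++j) {
      // Element (i, j) of the CK matrix becomes element (j, i) of KC.
      kc[j * c + i] = ck[i * k + j];
    }
  }
  return kc;
}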

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
   validator.SetOp(*desc.Proto());
   LOG(INFO) << "execute";
 
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt

paddle/fluid/inference/tensorrt/convert/test_fc_op.cc

Lines changed: 7 additions & 6 deletions
@@ -23,11 +23,11 @@ namespace tensorrt {
 TEST(fc_op, test) {
   std::unordered_set<std::string> parameters({"mul-Y"});
   framework::Scope scope;
-  TRTConvertValidation validator(20, parameters, scope, 1000);
-
-  validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
-  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
-  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1));
+  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
+  // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
+  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(1, 2));
 
   // Prepare Op description
   framework::OpDesc desc;
@@ -38,9 +38,10 @@ TEST(fc_op, test) {
 
   validator.SetOp(*desc.Proto());
 
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+USE_OP(mul);
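Note: the reworked fc test keeps the batch out of the data shape: mul-X is declared as 1 x 10 x 1 x 1, mul-Y as 10 x 2, so mul-Out is 1 x 2, and Execute(1) runs a single-sample batch. A plain-C++ sanity check of that shape arithmetic (no TensorRT; the names are illustrative):

#include <array>
#include <cassert>

int main() {
  constexpr int m = 1, k = 10, n = 2;  // (1, 10) x (10, 2) -> (1, 2)
  std::array<float, m * k> x{};        // mul-X flattened
  std::array<float, k * n> w{};        // mul-Y
  std::array<float, m * n> out{};      // mul-Out
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p) out[i * n + j] += x[i * k + p] * w[p * n + j];
  assert(out.size() == 2);  // matches DeclOutputVar("mul-Out", Dims2(1, 2))
  return 0;
}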

paddle/fluid/inference/tensorrt/convert/test_mul_op.cc

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) {
   validator.SetOp(*desc.Proto());
   LOG(INFO) << "execute";
 
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 5 additions & 5 deletions
@@ -39,7 +39,7 @@ namespace tensorrt {
 float random(float low, float high) {
   static std::random_device rd;
   static std::mt19937 mt(rd());
-  std::uniform_real_distribution<double> dist(1.0, 10.0);
+  std::uniform_real_distribution<double> dist(low, high);
   return dist(mt);
 }
 
@@ -49,6 +49,7 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
   size_t num_elements = analysis::AccuDims(dims, dims.size());
   PADDLE_ENFORCE_GT(num_elements, 0);
   auto* data = tensor->mutable_data<float>(place);
+
   for (size_t i = 0; i < num_elements; i++) {
     *(data + i) = random(0., 1.);
   }
@@ -68,7 +69,7 @@ class TRTConvertValidation {
                       int workspace_size = 1 << 10)
       : parameters_(parameters), scope_(scope) {
     // create engine.
-    engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
+    engine_.reset(new TensorRTEngine(batch_size, workspace_size, &stream_));
     engine_->InitNetwork();
 
     PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
@@ -138,12 +139,11 @@ class TRTConvertValidation {
     cudaStreamSynchronize(*engine_->stream());
 
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
-    const size_t output_space_size = 200;
+    const size_t output_space_size = 2000;
    for (const auto& output : op_desc_->OutputArgumentNames()) {
       std::vector<float> fluid_out;
       std::vector<float> trt_out(output_space_size);
-      engine_->GetOutputInCPU(output, &trt_out[0],
-                              output_space_size * sizeof(float));
+      engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
       cudaStreamSynchronize(*engine_->stream());
 
       auto* var = scope_.FindVar(output);
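Note: the first hunk fixes random() so that its low/high parameters are actually forwarded to the distribution; previously every call sampled from a hard-coded (1.0, 10.0) range regardless of the arguments. An isolated sketch of the corrected helper, under an illustrative name:

#include <random>

float RandomInRange(float low, float high) {
  static std::random_device rd;
  static std::mt19937 mt(rd());
  std::uniform_real_distribution<double> dist(low, high);  // honors the arguments
  return static_cast<float>(dist(mt));
}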

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 39 additions & 23 deletions
@@ -1,7 +1,7 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+this file except in compliance with the License.
 You may obtain a copy of the License at
 
     http://www.apache.org/licenses/LICENSE-2.0
@@ -26,6 +26,8 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+int TensorRTEngine::runtime_batch_ = 1;
+
 void TensorRTEngine::Build(const DescType &paddle_model) {
   PADDLE_ENFORCE(false, "not implemented");
 }
@@ -42,6 +44,7 @@ void TensorRTEngine::Execute(int batch_size) {
   PADDLE_ENFORCE_NOT_NULL(stream_);
   infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
   cudaStreamSynchronize(*stream_);
+  SetRuntimeBatch(batch_size);
 }
 
 TensorRTEngine::~TensorRTEngine() {
@@ -80,17 +83,17 @@ void TensorRTEngine::FreezeNetwork() {
       auto dims = infer_engine_->getBindingDimensions(slot_offset);
       item.second = kDataTypeSize[static_cast<int>(
                         infer_engine_->getBindingDataType(slot_offset))] *
-                    analysis::AccuDims(dims.d, dims.nbDims);
+                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
       PADDLE_ENFORCE_GT(item.second, 0);
     }
 
     auto &buf = buffer(item.first);
     buf.max_size = item.second * max_batch_;
     CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
-    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, buf.max_size));
-    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
-    // buf.size will changed in the runtime.
+
+    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
     buf.size = 0;
+    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
     buf.device = DeviceType::GPU;
   }
 }
@@ -105,7 +108,7 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
   auto *input = infer_network_->addInput(name.c_str(), dtype, dims);
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
   buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
-                        analysis::AccuDims(dims.d, dims.nbDims);
+                        analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
   PADDLE_ENFORCE(input->isNetworkInput());
   TensorRTEngine::SetITensor(name, input);
   return input;
@@ -149,35 +152,42 @@ void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
 void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
                                     size_t max_size) {
   // determine data size
+  auto *output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
+
   auto it = buffer_sizes_.find(name);
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
-  PADDLE_ENFORCE_GE(max_size, it->second);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                     cudaMemcpyDeviceToDevice, *stream_),
                     0);
 }
 
 void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
                                     size_t max_size) {
-  VLOG(4) << "get output in cpu";
-  auto &buf = buffer(name);
-
-  // Update needed buffer size.
-  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
-  auto dims = infer_engine_->getBindingDimensions(slot_offset);
-  buf.size = kDataTypeSize[static_cast<int>(
-                 infer_engine_->getBindingDataType(slot_offset))] *
-             analysis::AccuDims(dims.d, dims.nbDims);
-  PADDLE_ENFORCE_LE(buf.size, buf.max_size);
   // determine data size
+
+  auto *output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
+  auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  // DEBUG
-  memset(dst, 0, buf.size);
-  PADDLE_ENFORCE_EQ(
-      0, cudaMemcpy(dst, buf.buffer, buf.size, cudaMemcpyDeviceToHost));
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
+                                       cudaMemcpyDeviceToHost, *stream_));
 }
 
 Buffer &TensorRTEngine::buffer(const std::string &name) {
@@ -225,6 +235,12 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
   return itensor_map_[name];
 }
 
+void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
+  runtime_batch_ = batch_size;
+}
+
+int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
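Note: both GetOutputInGPU and GetOutputInCPU now derive the number of bytes to copy from the output tensor's per-sample dimensions, the runtime batch recorded by Execute(), and the element size, instead of the size cached when the network was frozen. A minimal sketch of that size rule, assuming bytes = prod(dims) * runtime_batch * element_size (illustrative helper, not the engine's API):

#include <cstddef>

size_t OutputCopyBytes(const int* dims, int nb_dims, int runtime_batch,
                       size_t elem_size) {
  size_t per_sample = 1;
  for (int i = 0; i < nb_dims; ++i) per_sample *= static_cast<size_t>(dims[i]);
  return per_sample * static_cast<size_t>(runtime_batch) * elem_size;
}

// For example, a (2, 1, 1) float output run with batch 1 gives
// OutputCopyBytes(dims, 3, 1, sizeof(float)) == 8 bytes, which lines up with
// the GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float)) call in the tests.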

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 4 additions & 0 deletions
@@ -117,10 +117,14 @@ class TensorRTEngine : public EngineBase {
 
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+  void SetRuntimeBatch(size_t batch_size);
+  int GetRuntimeBatch();
 
  private:
   // the max batch size
   int max_batch_;
+  // the runtime batch size
+  static int runtime_batch_;
   // the max memory size the engine uses
   int max_workspace_;

paddle/fluid/inference/tensorrt/test_engine.cc

Lines changed: 37 additions & 3 deletions
@@ -28,7 +28,7 @@ class TensorRTEngineTest : public ::testing::Test {
  protected:
   void SetUp() override {
     ASSERT_EQ(0, cudaStreamCreate(&stream_));
-    engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
+    engine_ = new TensorRTEngine(10, 1 << 10, &stream_);
     engine_->InitNetwork();
   }
 
@@ -71,7 +71,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
 
   LOG(INFO) << "to get output";
   float y_cpu;
-  engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
+  engine_->GetOutputInCPU("y", &y_cpu, 1 * sizeof(float));
 
   LOG(INFO) << "to checkout output";
   ASSERT_EQ(y_cpu, x_v * 2 + 3);
@@ -103,15 +103,49 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
 
   LOG(INFO) << "to get output";
   float y_cpu[2] = {-1., -1.};
+
   auto dims = engine_->GetITensor("y")->getDimensions();
   ASSERT_EQ(dims.nbDims, 3);
   ASSERT_EQ(dims.d[0], 2);
   ASSERT_EQ(dims.d[1], 1);
-  engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2);
+  engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
   ASSERT_EQ(y_cpu[0], 4.5);
   ASSERT_EQ(y_cpu[1], 14.5);
 }
 
+TEST_F(TensorRTEngineTest, test_conv2d_temp) {
+  // Weight in CPU memory.
+  float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  float raw_bias[1] = {0};
+
+  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 9);
+  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 1);
+  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
+                                  nvinfer1::Dims3{1, 3, 3});
+  auto* conv_layer =
+      TRT_ENGINE_ADD_LAYER(engine_, Convolution, *x, 1, nvinfer1::DimsHW{3, 3},
+                           weight.get(), bias.get());
+  PADDLE_ENFORCE(conv_layer != nullptr);
+  conv_layer->setStride(nvinfer1::DimsHW{1, 1});
+  conv_layer->setPadding(nvinfer1::DimsHW{1, 1});
+
+  engine_->DeclareOutput(conv_layer, 0, "y");
+  engine_->FreezeNetwork();
+  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+  float x_v[18] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                   1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
+                           18 * sizeof(float));
+  engine_->Execute(2);
+
+  LOG(INFO) << "to get output";
+  float* y_cpu = new float[18];
+  engine_->GetOutputInCPU("y", &y_cpu[0], 18 * sizeof(float));
+  ASSERT_EQ(y_cpu[0], 4.0);
+  ASSERT_EQ(y_cpu[1], 6.0);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
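Note: the expected values in test_conv2d_temp follow from the all-ones setup: with a 3 x 3 all-ones kernel, zero padding of 1, stride 1, and a 3 x 3 all-ones input, a corner output sums a 2 x 2 window (4.0) and an edge output sums a 2 x 3 window (6.0), which is exactly what the two assertions check. A plain CPU reference for those values (no TensorRT involved):

#include <cstdio>

int main() {
  const int h = 3, w = 3;
  float out[3][3];
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      float sum = 0.f;
      for (int di = -1; di <= 1; ++di)
        for (int dj = -1; dj <= 1; ++dj) {
          int y = i + di, x = j + dj;
          // Input and kernel are all ones, so each in-bounds tap adds 1.
          if (y >= 0 && y < h && x >= 0 && x < w) sum += 1.f;
        }
      out[i][j] = sum;
    }
  }
  std::printf("%.1f %.1f\n", out[0][0], out[0][1]);  // prints "4.0 6.0"
  return 0;
}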

paddle/fluid/operators/tensorrt_engine_op.cc

Lines changed: 4 additions & 3 deletions
@@ -55,13 +55,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
                     "TensorRT' tensor input requires at least 2 dimensions");
   PADDLE_ENFORCE_LE(shape.size(), 4UL,
                     "TensorRT' tensor input requires at most 4 dimensions");
+
   switch (shape.size()) {
     case 2:
-      return nvinfer1::Dims2(shape[0], shape[1]);
+      return nvinfer1::Dims2(1, shape[1]);
     case 3:
-      return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
+      return nvinfer1::Dims3(1, shape[1], shape[2]);
     case 4:
-      return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
+      return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
     default:
       return nvinfer1::Dims();
   }
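Note: Vec2TRT_Dims now pins the leading dimension to 1 instead of propagating the fluid batch size (shape[0]), since the batch count is supplied to the engine separately at execution time. A simplified stand-in for that behavior (illustrative names, not the operator's actual signature):

#include <cstdio>
#include <vector>

std::vector<int> ToTrtDims(const std::vector<long long>& shape) {
  std::vector<int> dims(shape.begin(), shape.end());
  dims[0] = 1;  // drop the fluid batch size; the engine receives it at enqueue time
  return dims;
}

int main() {
  auto d = ToTrtDims({8, 3, 224, 224});
  std::printf("%d %d %d %d\n", d[0], d[1], d[2], d[3]);  // 1 3 224 224
  return 0;
}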

paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 3 additions & 1 deletion
@@ -93,13 +93,15 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
     auto* fluid_v = context.scope().FindVar(y);
     PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
     auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
-    auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
+
     fluid_t->Resize(framework::make_ddim(ddim));
 
     // TODO(Superjomn) find some way to determine which device to output the
     // tensor.
     // if (platform::is_cpu_place(fluid_t->place())) {
     // TODO(Superjomn) change this float to dtype size.
+    auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
+                FLAGS_tensorrt_engine_batch_size;
     engine->GetOutputInCPU(y,
                            fluid_t->mutable_data<float>(platform::CPUPlace()),
                            size * sizeof(float));
