Skip to content

Commit 0c0c5df

Browse files
authored
feature/add TRT fc converter (#11043)
1 parent 18d6402 commit 0c0c5df

File tree

12 files changed

+240
-35
lines changed

12 files changed

+240
-35
lines changed

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
88
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
99
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
1010
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
11+
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
12+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace tensorrt {
2121
class Conv2dOpConverter : public OpConverter {
2222
public:
2323
Conv2dOpConverter() {}
24-
void operator()(const framework::proto::OpDesc& op) override {
24+
void operator()(const framework::proto::OpDesc& op,
25+
const framework::Scope& scope) override {
2526
LOG(INFO)
2627
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
2728
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/eigen.h"
16+
#include "paddle/fluid/framework/lod_tensor.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
19+
#include "paddle/fluid/inference/tensorrt/engine.h"
20+
#include "paddle/fluid/platform/place.h"
21+
22+
namespace paddle {
23+
namespace inference {
24+
namespace tensorrt {
25+
26+
// Reorder the elements from istrides to ostrides, borrowed from TRT convert in
27+
// tensorflow.
28+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
29+
template <typename T>
30+
void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
31+
T* odata, nvinfer1::DimsHW ostrides) {
32+
for (int h = 0; h < shape.h(); ++h) {
33+
for (int w = 0; w < shape.w(); ++w) {
34+
odata[h * ostrides.h() + w * ostrides.w()] =
35+
idata[h * ostrides.h() + w * ostrides.w()];
36+
}
37+
}
38+
}
39+
40+
// Reorder the data layout from CK to KC.
41+
void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
42+
TensorRTEngine::Weight* oweights) {
43+
int c = iweights.dims[0];
44+
int k = iweights.dims[1];
45+
oweights->dims.assign({k, c});
46+
nvinfer1::DimsHW istrides = {1, k};
47+
nvinfer1::DimsHW ostrides = {c, 1};
48+
Reorder2({k, c}, static_cast<float const*>(iweights.get().values), istrides,
49+
static_cast<float*>(const_cast<void*>(oweights->get().values)),
50+
ostrides);
51+
}
52+
53+
/*
54+
* FC converter convert a MUL op in Fluid to a FC layer in TRT.
55+
*/
56+
class FcOpConverter : public OpConverter {
57+
public:
58+
void operator()(const framework::proto::OpDesc& op,
59+
const framework::Scope& scope) override {
60+
VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";
61+
62+
framework::OpDesc op_desc(op, nullptr, nullptr);
63+
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
64+
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
65+
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
66+
67+
// Declare inputs
68+
auto* X = engine_->GetITensor(op_desc.Input("X").front());
69+
70+
// Declare weights
71+
auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
72+
PADDLE_ENFORCE_NOT_NULL(Y_v);
73+
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
74+
// This may trigger a GPU->CPU copy, because TRT's weight can only be
75+
// assigned from CPU memory, that can't be avoided.
76+
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
77+
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix
78+
size_t n_output = Y_t->dims()[1];
79+
80+
framework::LoDTensor tmp;
81+
tmp.Resize(Y_t->dims());
82+
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
83+
Y_t->dims()[0] * Y_t->dims()[1]);
84+
85+
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
86+
static_cast<void*>(weight_data),
87+
Y_t->memory_size() / sizeof(float)};
88+
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
89+
static_cast<void*>(tmp.data<float>()),
90+
Y_t->memory_size() / sizeof(float));
91+
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
92+
tmp_weight.dims = weight.dims;
93+
94+
// The data layout of TRT FC layer's weight is different from fluid's FC,
95+
// need to reorder the elements.
96+
ReorderCKtoKC(tmp_weight, &weight);
97+
98+
// Currently, the framework can only handle one fluid op -> one TRT layer,
99+
// but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
100+
// handle `mul`, leave `add` as another layer.
101+
// DEBUG
102+
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
103+
104+
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
105+
*const_cast<nvinfer1::ITensor*>(X),
106+
n_output, weight.get(), bias.get());
107+
108+
auto output_name = op_desc.Output("Out").front();
109+
engine_->DeclareOutput(layer, 0, output_name);
110+
}
111+
};
112+
113+
// Makes FcOpConverter available under the key "fc", the name that
// OpConverter::ConvertOp looks up for `mul` ops whose Y input is a parameter.
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
114+
115+
} // namespace tensorrt
116+
} // namespace inference
117+
} // namespace paddle
118+
119+
// NOTE(review): presumably pulls in the fluid `mul` operator registration so
// the op can be created and run next to this converter (e.g. by the unit
// test) - confirm.
USE_OP(mul);

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ namespace tensorrt {
2424
class MulOpConverter : public OpConverter {
2525
public:
2626
MulOpConverter() {}
27-
void operator()(const framework::proto::OpDesc& op) override {
28-
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
27+
void operator()(const framework::proto::OpDesc& op,
28+
const framework::Scope& scope) override {
29+
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
2930

3031
framework::OpDesc op_desc(op, nullptr);
3132
// Declare inputs

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,42 @@ namespace tensorrt {
3131
class OpConverter {
3232
public:
3333
OpConverter() {}
34-
virtual void operator()(const framework::proto::OpDesc& op) {}
3534

36-
void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
37-
std::string type = op.type();
38-
auto* it = Registry<OpConverter>::Lookup(type);
39-
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
40-
it->SetEngine(engine);
41-
(*it)(op);
42-
}
35+
// Converter logic for an op.
36+
virtual void operator()(const framework::proto::OpDesc& op,
37+
const framework::Scope& scope) {}
38+
39+
// Convert a single fluid operator and add the corresponding layer to TRT.
40+
void ConvertOp(const framework::proto::OpDesc& op,
41+
const std::unordered_set<std::string>& parameters,
42+
const framework::Scope& scope, TensorRTEngine* engine) {
43+
framework::OpDesc op_desc(op, nullptr, nullptr);
44+
45+
OpConverter* it{nullptr};
4346

44-
// convert fluid op to tensorrt layer
45-
void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
46-
OpConverter::Run(op, engine);
47+
if (op_desc.Type() == "mul") {
48+
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
49+
std::string Y = op_desc.Input("Y")[0];
50+
if (parameters.count(Y)) {
51+
it = Registry<OpConverter>::Lookup("fc");
52+
}
53+
}
54+
if (!it) {
55+
it = Registry<OpConverter>::Lookup(op_desc.Type());
56+
}
57+
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
58+
op_desc.Type());
59+
it->SetEngine(engine);
60+
(*it)(op, scope);
4761
}
4862

4963
// convert fluid block to tensorrt network
5064
void ConvertBlock(const framework::proto::BlockDesc& block,
51-
TensorRTEngine* engine) {
65+
const std::unordered_set<std::string>& parameters,
66+
const framework::Scope& scope, TensorRTEngine* engine) {
5267
for (int i = 0; i < block.ops_size(); i++) {
5368
const auto& op = block.ops(i);
54-
OpConverter::Run(op, engine);
69+
ConvertOp(op, parameters, scope, engine);
5570
}
5671
}
5772

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
// Runs a fluid `mul` op whose Y input is registered as a parameter through
// TRTConvertValidation, which converts the op, executes both the fluid op and
// the TRT engine, and compares their outputs.
TEST(fc_op, test) {
  std::unordered_set<std::string> parameters({"mul-Y"});
  framework::Scope scope;
  TRTConvertValidation validator(20, parameters, scope, 1000);

  // Describe the op under test: Out = mul(X, Y).
  framework::OpDesc mul_desc;
  mul_desc.SetType("mul");
  mul_desc.SetInput("X", {"mul-X"});
  mul_desc.SetInput("Y", {"mul-Y"});
  mul_desc.SetOutput("Out", {"mul-Out"});

  // Declare the variables the op reads and writes.
  validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));

  validator.SetOp(*mul_desc.Proto());

  validator.Execute(10);
}
43+
44+
} // namespace tensorrt
45+
} // namespace inference
46+
} // namespace paddle

paddle/fluid/inference/tensorrt/convert/test_mul_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ namespace inference {
2121
namespace tensorrt {
2222

2323
TEST(MulOpConverter, main) {
24-
TRTConvertValidation validator(10, 1000);
24+
framework::Scope scope;
25+
std::unordered_set<std::string> parameters;
26+
TRTConvertValidation validator(10, parameters, scope, 1000);
2527
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
2628
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
2729
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));

paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
16+
1517
#include <gtest/gtest.h>
1618
#include "paddle/fluid/framework/program_desc.h"
17-
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
1819

1920
namespace paddle {
2021
namespace inference {
@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
2728
conv2d_op->SetType("conv2d");
2829

2930
OpConverter converter;
30-
converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
31+
framework::Scope scope;
32+
converter.ConvertBlock(*block->Proto(), {}, scope,
33+
nullptr /*TensorRTEngine*/);
3134
}
3235

3336
} // namespace tensorrt

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ class TRTConvertValidation {
6161
public:
6262
TRTConvertValidation() = delete;
6363

64-
explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
64+
TRTConvertValidation(int batch_size,
65+
const std::unordered_set<std::string>& parameters,
66+
framework::Scope& scope, int workspace_size = 1 << 10)
67+
: parameters_(parameters), scope_(scope) {
6568
// create engine.
6669
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
6770
engine_->InitNetwork();
@@ -76,19 +79,22 @@ class TRTConvertValidation {
7679
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
7780
}
7881

82+
// Declare a parameter variable in the scope.
83+
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
84+
DeclVar(name, dims);
85+
}
86+
7987
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
8088
DeclVar(name, dims);
8189
}
8290

91+
// Declare a variable in a fluid Scope.
8392
void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
8493
platform::CPUPlace place;
8594
platform::CPUDeviceContext ctx(place);
8695

8796
// Init Fluid tensor.
88-
std::vector<int> dim_vec(dims.nbDims);
89-
for (int i = 0; i < dims.nbDims; i++) {
90-
dim_vec[i] = dims.d[i];
91-
}
97+
std::vector<int> dim_vec(dims.d, dims.d + dims.nbDims);
9298
auto* x = scope_.Var(name);
9399
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
94100
x_tensor->Resize(framework::make_ddim(dim_vec));
@@ -99,7 +105,7 @@ class TRTConvertValidation {
99105
op_ = framework::OpRegistry::CreateOp(desc);
100106

101107
OpConverter op_converter;
102-
op_converter.ConvertOp(desc, engine_.get());
108+
op_converter.ConvertOp(desc, parameters_, scope_, engine_.get());
103109

104110
engine_->FreezeNetwork();
105111

@@ -108,38 +114,43 @@ class TRTConvertValidation {
108114

109115
// Set Inputs.
110116
for (const auto& input : op_desc_->InputArgumentNames()) {
117+
if (parameters_.count(input)) continue;
111118
auto* var = scope_.FindVar(input);
112119
PADDLE_ENFORCE(var);
113120
auto tensor = var->GetMutable<framework::LoDTensor>();
121+
114122
engine_->SetInputFromCPU(
115-
input, static_cast<void*>(tensor->data<float>()),
123+
input, static_cast<void*>(tensor->data<void>()),
116124
sizeof(float) *
117125
analysis::AccuDims(tensor->dims(), tensor->dims().size()));
118126
}
119127
}
120128

121129
void Execute(int batch_size) {
122130
// Execute Fluid Op
123-
// Execute TRT
124131
platform::CPUPlace place;
125132
platform::CPUDeviceContext ctx(place);
126-
engine_->Execute(batch_size);
127-
128133
op_->Run(scope_, place);
134+
// Execute TRT.
135+
engine_->Execute(batch_size);
136+
cudaStreamSynchronize(*engine_->stream());
129137

130138
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
139+
const size_t output_space_size = 200;
131140
for (const auto& output : op_desc_->OutputArgumentNames()) {
132141
std::vector<float> fluid_out;
133-
std::vector<float> trt_out(200);
134-
engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));
142+
std::vector<float> trt_out(output_space_size);
143+
engine_->GetOutputInCPU(output, &trt_out[0],
144+
output_space_size * sizeof(float));
145+
cudaStreamSynchronize(*engine_->stream());
135146

136147
auto* var = scope_.FindVar(output);
137148
auto tensor = var->GetMutable<framework::LoDTensor>();
138149
framework::TensorToVector(*tensor, ctx, &fluid_out);
139150
// Compare two output
140151
ASSERT_FALSE(fluid_out.empty());
141152
for (size_t i = 0; i < fluid_out.size(); i++) {
142-
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
153+
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6);
143154
}
144155
}
145156
}
@@ -149,9 +160,10 @@ class TRTConvertValidation {
149160
private:
150161
std::unique_ptr<TensorRTEngine> engine_;
151162
cudaStream_t stream_;
152-
framework::Scope scope_;
153163
std::unique_ptr<framework::OperatorBase> op_;
154164
std::unique_ptr<framework::OpDesc> op_desc_;
165+
const std::unordered_set<std::string>& parameters_;
166+
framework::Scope& scope_;
155167
};
156168

157169
} // namespace tensorrt

0 commit comments

Comments
 (0)