Skip to content

Commit 830aa12

Browse files
committed
add elementwise init code
1 parent fb204fb commit 830aa12

File tree

5 files changed

+313
-1
lines changed

5 files changed

+313
-1
lines changed
Binary file not shown.

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Add TRT tests
22
nv_library(tensorrt_converter
3-
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc
3+
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
44
DEPS tensorrt_engine mul_op)
55

66
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -16,3 +16,6 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
1616

1717
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
1818
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
19+
20+
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
21+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace tensorrt {
21+
22+
// Converts a fluid elementwise op whose Y input is a persistable weight
// (looked up in the scope) into a TensorRT IScaleLayer that applies the
// weight as a per-channel or per-element shift.
class ElementwiseWeightOpConverter : public OpConverter {
 public:
  ElementwiseWeightOpConverter() {}
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    // Here the two nullptr looks strange, that's because the
    // framework::OpDesc's constructor is strange.
    framework::OpDesc op_desc(op, nullptr);
    LOG(INFO) << "convert a fluid elementwise op to tensorrt IScaleLayer";

    PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
    PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);  // Y is a weight
    PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);

    auto* X = engine_->GetITensor(op_desc.Input("X").front());
    nvinfer1::Dims dims_x = X->getDimensions();
    // TensorRT dims here exclude the batch dim; require at least C*H*W.
    PADDLE_ENFORCE(dims_x.nbDims >= 3);

    auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
    PADDLE_ENFORCE_NOT_NULL(Y_v);
    auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
    // Weight data must live on CPU so TensorRT can copy it at build time.
    auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
    auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;

    std::vector<int> dims_y = framework::vectorize2int(Y_t->dims());
    // Fluid weights may carry a leading batch dim of 1; strip it so the
    // weight rank matches the (batch-less) TensorRT input rank.
    if (static_cast<int>(dims_y.size()) == dims_x.nbDims + 1) {
      if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
    }

    // Decide how the weight is broadcast:
    //  - 1-D weight of length C          -> kCHANNEL (one value per channel)
    //  - full-rank weight equal to X     -> kELEMENTWISE
    //  - full-rank (C, 1, 1, ...) weight -> kCHANNEL
    // anything else is unsupported.
    if (static_cast<int>(dims_y.size()) == 1 && dims_y[0] == dims_x.d[0]) {
      scale_mode = nvinfer1::ScaleMode::kCHANNEL;
    } else if (static_cast<int>(dims_y.size()) == dims_x.nbDims &&
               dims_y[0] == dims_x.d[0]) {
      scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
      for (int i = 1; i < dims_x.nbDims; i++) {
        if (dims_y[i] != dims_x.d[i]) {
          scale_mode = nvinfer1::ScaleMode::kCHANNEL;
          break;
        }
      }
      // kCHANNEL is only valid if every trailing weight dim is 1.
      if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
        for (int i = 1; i < dims_x.nbDims; i++) {
          if (dims_y[i] != 1)
            PADDLE_THROW(
                "TensorRT unsupported weight shape for Elementwise op!");
        }
      }
    } else {
      PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!");
    }

    // IScaleLayer computes (x + shift) * scale ^ power; elementwise_add only
    // needs the shift term, so scale and power are left empty.
    TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
                                         static_cast<void*>(weight_data),
                                         Y_t->memory_size() / sizeof(float)};
    TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
                                         0};
    TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
                                         0};

    nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *const_cast<nvinfer1::ITensor*>(X), scale_mode,
        shift_weights.get(), scale_weights.get(), power_weights.get());
    auto output_name = op_desc.Output("Out")[0];
    engine_->SetITensor(output_name, layer->getOutput(0));
    if (test_mode) {  // the test framework can not determine which is the
                      // output, so place the declaration inside.
      engine_->DeclareOutput(output_name);
    }
  }
};
92+
93+
class ElementwiseTensorOpConverter : public OpConverter {
94+
public:
95+
ElementwiseTensorOpConverter() {}
96+
void operator()(const framework::proto::OpDesc& op,
97+
const framework::Scope& scope, bool test_mode) override {
98+
// Here the two nullptr looks strange, that's because the
99+
// framework::OpDesc's constructor is strange.
100+
framework::OpDesc op_desc(op, nullptr);
101+
LOG(INFO) << "convert a fluid elementwise op to tensorrt IScaleLayer";
102+
103+
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
104+
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
105+
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
106+
107+
auto* X = engine_->GetITensor(op_desc.Input("X").front());
108+
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
109+
nvinfer1::Dims dims_x = X->getDimensions();
110+
nvinfer1::Dims dims_y = Y->getDimensions();
111+
112+
// only support the C * H * W input format
113+
PADDLE_ENFORCE(dims_x.nbDims >= 3);
114+
if (dims_x.nbDims == dims_y.nbDims) {
115+
for (int i = 0; i < dims_x.nbDims; i++) {
116+
if (dims_x.d[i] != dims_y.d[i])
117+
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
118+
}
119+
} else {
120+
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
121+
}
122+
123+
auto op_pair = ops.find(op_type_);
124+
if (op_pair == ops.end()) {
125+
PADDLE_THROW("Wrong elementwise op type!");
126+
}
127+
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
128+
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
129+
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
130+
131+
auto output_name = op_desc.Output("Out")[0];
132+
engine_->SetITensor(output_name, layer->getOutput(0));
133+
if (test_mode) { // the test framework can not determine which is the
134+
// output, so place the declaration inside.
135+
engine_->DeclareOutput(output_name);
136+
}
137+
}
138+
139+
protected:
140+
static const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
141+
ops;
142+
std::string op_type_;
143+
};
144+
145+
// Definition of the static suffix -> TensorRT operation table declared in
// ElementwiseTensorOpConverter. Keys match the 3-character suffixes that
// OpConverter extracts from fluid op types such as "elementwise_add".
const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
    ElementwiseTensorOpConverter::ops = {
        {"add", nvinfer1::ElementWiseOperation::kSUM},
        {"mul", nvinfer1::ElementWiseOperation::kPROD},
        {"sub", nvinfer1::ElementWiseOperation::kSUB},
        {"div", nvinfer1::ElementWiseOperation::kDIV},
        {"min", nvinfer1::ElementWiseOperation::kMIN},
        {"pow", nvinfer1::ElementWiseOperation::kPOW},
        {"max", nvinfer1::ElementWiseOperation::kMAX},
};
155+
156+
// One thin subclass per supported elementwise op. Each constructor only sets
// op_type_ to the matching key of ElementwiseTensorOpConverter::ops; all
// conversion logic lives in the shared base class.
class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorAddOpConverter() { op_type_ = "add"; }
};

class ElementwiseTensorMulOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorMulOpConverter() { op_type_ = "mul"; }
};

class ElementwiseTensorSubOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorSubOpConverter() { op_type_ = "sub"; }
};

class ElementwiseTensorDivOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorDivOpConverter() { op_type_ = "div"; }
};

class ElementwiseTensorMinOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorMinOpConverter() { op_type_ = "min"; }
};

class ElementwiseTensorMaxOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorMaxOpConverter() { op_type_ = "max"; }
};

class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter {
 public:
  ElementwiseTensorPowOpConverter() { op_type_ = "pow"; }
};
191+
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// elementwise_add with a persistable (weight) Y input is converted via the
// IScaleLayer-based weight converter.
REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, ElementwiseWeightOpConverter);

// elementwise_* ops whose Y input is a runtime tensor are converted via the
// IElementWiseLayer-based tensor converters.
REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor,
                          ElementwiseTensorAddOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_sub_tensor,
                          ElementwiseTensorSubOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_div_tensor,
                          ElementwiseTensorDivOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_mul_tensor,
                          ElementwiseTensorMulOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_max_tensor,
                          ElementwiseTensorMaxOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_min_tensor,
                          ElementwiseTensorMinOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_pow_tensor,
                          ElementwiseTensorPowOpConverter);

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,31 @@ class OpConverter {
5555
it = Registry<OpConverter>::Lookup("fc");
5656
}
5757
}
58+
59+
if (op_desc.Type().find("elementwise") != std::string::npos) {
60+
static std::unordered_set<std::string> add_tensor_op_set{
61+
"add", "mul", "sub", "div", "max", "min", "pow"};
62+
// TODO(xingzhaolong): all mul, sub, div
63+
// static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
64+
// "sub", "div"};
65+
static std::unordered_set<std::string> add_weight_op_set{"add"};
66+
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
67+
int op_type_len = op_desc.Type().size();
68+
std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len);
69+
std::string Y = op_desc.Input("Y")[0];
70+
if (parameters.count(Y)) {
71+
PADDLE_ENFORCE(add_weight_op_set.count(op_type) > 0,
72+
"Unsupported elementwise type" + op_type);
73+
it =
74+
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight");
75+
} else {
76+
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
77+
"Unsupported elementwise type" + op_type);
78+
it =
79+
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor");
80+
}
81+
}
82+
5883
if (!it) {
5984
it = Registry<OpConverter>::Lookup(op_desc.Type());
6085
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
// Y is declared as a parameter (persistable weight), so op_converter.h should
// dispatch to the elementwise_add_weight / IScaleLayer path.
TEST(elementwise_op, add_weight_test) {
  std::unordered_set<std::string> parameters({"elementwise_add-Y"});
  framework::Scope scope;
  TRTConvertValidation validator(1, parameters, scope, 1 << 15);
  validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
  // Weight shape (C, 1, 1): exercises the per-channel (kCHANNEL) shift.
  validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
  validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("elementwise_add");
  desc.SetInput("X", {"elementwise_add-X"});
  desc.SetInput("Y", {"elementwise_add-Y"});
  desc.SetOutput("Out", {"elementwise_add-Out"});

  int axis = 1;
  desc.SetAttr("axis", axis);

  validator.SetOp(*desc.Proto());

  // Compare fluid vs TensorRT results for batch size 1.
  validator.Execute(1);
}
47+
// Y is a runtime tensor (no parameters declared), so op_converter.h should
// dispatch to the elementwise_add_tensor / IElementWiseLayer path.
TEST(elementwise_op, add_tensor_test) {
  std::unordered_set<std::string> parameters;
  framework::Scope scope;
  TRTConvertValidation validator(1, parameters, scope, 1 << 15);
  validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
  // X and Y must have identical shapes for the tensor path.
  validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
  validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("elementwise_add");
  desc.SetInput("X", {"elementwise_add-X"});
  desc.SetInput("Y", {"elementwise_add-Y"});
  desc.SetOutput("Out", {"elementwise_add-Out"});

  int axis = 1;
  desc.SetAttr("axis", axis);

  validator.SetOp(*desc.Proto());

  // Compare fluid vs TensorRT results for batch size 1.
  validator.Execute(1);
}
70+
71+
} // namespace tensorrt
72+
} // namespace inference
73+
} // namespace paddle
74+
USE_OP(elementwise_add);

0 commit comments

Comments
 (0)