Commit ec56c75

NHZlX authored and Paddle CI committed
concat op converter && mobilenet && resnet && map cnn model support
(cherry-pick from commit 3de4556)
1 parent d91e84a commit ec56c75

File tree: 6 files changed (+121, -5 lines)


paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
     auto trt_teller = [&](const Node* node) {
       std::unordered_set<std::string> teller_set(
           {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
-           "depthwise_conv2d", "batch_norm"});
+           "depthwise_conv2d", "batch_norm", "concat"});
       if (!node->IsFunction()) return false;
 
       const auto* func = static_cast<const Function*>(node);
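
Note: the trt_teller lambda here is what the subgraph pass consults when deciding which operator nodes may be handed to TensorRT, so adding "concat" to teller_set is what allows concat-containing subgraphs to be carved out for the engine; the new converter added below is what then maps those nodes onto TensorRT layers.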

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
       : NativePaddlePredictor(config), config_(config) {}
 
   bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
+    FLAGS_IA_enable_tensorrt_subgraph_engine = true;
     VLOG(3) << "Predictor::init()";
     FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
     FLAGS_tensorrt_workspace_size = config_.workspace_size;
@@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc);
 USE_TRT_CONVERTER(pool2d);
 USE_TRT_CONVERTER(softmax);
 USE_TRT_CONVERTER(batch_norm);
+USE_TRT_CONVERTER(concat);
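
For context, a minimal usage sketch of the path this file sits on. Only max_batch_size, workspace_size, and FLAGS_IA_enable_tensorrt_subgraph_engine appear in this diff; the TensorRTConfig struct, its inherited NativeConfig fields, and the CreatePaddlePredictor<..., kAutoMixedTensorRT> entry point are assumptions about the contemporaneous C++ inference API, and the model path is hypothetical.

// Sketch only, not part of the commit.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::TensorRTConfig config;           // assumed config type
  config.model_dir = "./mobilenet_model";  // hypothetical model directory
  config.use_gpu = true;
  config.max_batch_size = 1;               // forwarded to FLAGS_tensorrt_max_batch_size
  config.workspace_size = 1 << 30;         // forwarded to FLAGS_tensorrt_workspace_size

  auto predictor = paddle::CreatePaddlePredictor<
      paddle::TensorRTConfig, paddle::PaddleEngineKind::kAutoMixedTensorRT>(config);
  // Init() above now switches the analysis pass on via
  // FLAGS_IA_enable_tensorrt_subgraph_engine, so predictor->Run(...) routes
  // supported subgraphs (now including concat) through TensorRT.
  return 0;
}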

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc
   DEPS tensorrt_engine operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -18,12 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL)
 nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
-
 nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
-
 nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
-
 nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
+
+nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
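
Once built with TensorRT enabled, the new unit test should be runnable the same way as the other converter tests, presumably via "ctest -R test_trt_concat_op" from the build directory.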

paddle/fluid/inference/tensorrt/convert/concat_op.cc

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * ConcatOp, IConcatenationLayer in TRT. This Layer doesn't have weights.
 */
class ConcatOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert a fluid concat op to tensorrt concat layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    std::vector<nvinfer1::ITensor*> itensors;
    for (auto& input_name : op_desc.Input("X")) {
      itensors.push_back(engine_->GetITensor(input_name));
    }
    int axis = boost::get<int>(op_desc.GetAttr("axis"));
    PADDLE_ENFORCE(axis > 0,
                   "The axis attr of Concat op should be larger than 0 for trt");

    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
                                       itensors.size());
    axis = axis - 1;  // Remove batch dim
    layer->setAxis(axis);
    auto output_name = op_desc.Output("Out")[0];
    engine_->SetITensor(output_name, layer->getOutput(0));
    if (test_mode) {  // the test framework cannot determine which is the
                      // output, so place the declaration inside.
      engine_->DeclareOutput(output_name);
    }
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(concat, ConcatOpConverter);
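
To make the axis handling concrete, here is a small sketch of what TRT_ENGINE_ADD_LAYER(engine_, Concatenation, ...) followed by setAxis corresponds to in the raw, implicit-batch TensorRT API of that era; the tensor names and shapes are illustrative and mirror the unit test below. Fluid's concat axis counts the batch dimension (axis 1 of NCHW), while the TensorRT tensors here are CHW without the batch, hence the converter's axis - 1.

// Illustrative sketch, not part of the commit.
#include <NvInfer.h>
#include <iostream>
#include <vector>

class SimpleLogger : public nvinfer1::ILogger {
 public:
  void log(Severity severity, const char* msg) override {
    if (severity != Severity::kINFO) std::cerr << msg << std::endl;
  }
};

int main() {
  SimpleLogger logger;
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  nvinfer1::INetworkDefinition* network = builder->createNetwork();

  // Three CHW inputs with channel sizes 10, 3 and 7 (batch dim is implicit).
  std::vector<nvinfer1::ITensor*> inputs = {
      network->addInput("x1", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW(10, 3, 1)),
      network->addInput("x2", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW(3, 3, 1)),
      network->addInput("x3", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW(7, 3, 1))};

  auto* concat =
      network->addConcatenation(inputs.data(), static_cast<int>(inputs.size()));
  concat->setAxis(0);  // Fluid axis 1 (NCHW) -> TRT axis 0 (CHW), i.e. axis - 1
  network->markOutput(*concat->getOutput(0));  // resulting shape: (20, 3, 1)

  network->destroy();
  builder->destroy();
  return 0;
}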

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 8 additions & 0 deletions
@@ -79,6 +79,14 @@ class OpConverter {
         it =
             Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor");
       }
+      PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
+                              op_desc.Type());
+    }
+
+    if (op_desc.Type() == "depthwise_conv2d") {
+      it = Registry<OpConverter>::Lookup("conv2d");
+      PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
+                              op_desc.Type());
     }
 
     if (!it) {
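
The second hunk makes "depthwise_conv2d" resolve to the converter registered under "conv2d", presumably because TensorRT's convolution layer already covers depthwise convolution through its group count, so no separate converter is needed. A simplified stand-in (not Paddle's actual Registry) just to illustrate that aliasing:

// Toy illustration of the lookup aliasing, not Paddle code.
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::string> registry = {
      {"conv2d", "Conv2dOpConverter"},
      {"concat", "ConcatOpConverter"},
  };

  auto lookup = [&](std::string op_type) {
    if (op_type == "depthwise_conv2d") op_type = "conv2d";  // the alias above
    auto it = registry.find(op_type);
    return it == registry.end() ? std::string("<no converter>") : it->second;
  };

  std::cout << lookup("depthwise_conv2d") << "\n";  // Conv2dOpConverter
  std::cout << lookup("concat") << "\n";            // ConcatOpConverter
  return 0;
}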

paddle/fluid/inference/tensorrt/convert/test_concat_op.cc

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(concat_op, test) {
  std::unordered_set<std::string> parameters({""});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1));
  validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1));
  validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1));
  validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("concat");
  desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"});
  desc.SetOutput("Out", {"concat_out"});

  int axis = 1;
  desc.SetAttr("axis", axis);

  validator.SetOp(*desc.Proto());

  validator.Execute(5);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
USE_OP(concat);
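
The declared shapes spell out the expectation: three inputs with leading (channel) dimensions 10, 3 and 7 are concatenated on Fluid axis 1, which becomes the channel axis once the implicit batch dimension is removed, so concat_out must have 10 + 3 + 7 = 20 channels; Execute(5) presumably compares the TensorRT result against the Fluid reference with batch size 5.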
