Skip to content

Commit fc41eb4

Browse files
committed
add conv2d trt converter
1 parent 6169d72 commit fc41eb4

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
1313
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
1414
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
1515
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL)
16+
17+
nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
18+
DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL)

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,59 @@ namespace tensorrt {
2020

2121
class Conv2dOpConverter : public OpConverter {
2222
public:
23-
Conv2dOpConverter() {}
2423
void operator()(const framework::proto::OpDesc& op,
2524
const framework::Scope& scope, bool test_mode) override {
2625
LOG(INFO)
2726
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
27+
28+
framework::OpDesc op_desc(op, nullptr);
29+
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1);
30+
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight
31+
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1);
32+
33+
auto* X = engine_->GetITensor(op_desc.Input("Input").front());
34+
// Declare weights
35+
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front());
36+
PADDLE_ENFORCE_NOT_NULL(Y_v);
37+
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
38+
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
39+
40+
const int n_output = Y_t->dims()[0];
41+
const int filter_h = Y_t->dims()[2];
42+
const int filter_w = Y_t->dims()[3];
43+
44+
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
45+
const std::vector<int> dilations =
46+
boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
47+
const std::vector<int> strides =
48+
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
49+
const std::vector<int> paddings =
50+
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
51+
52+
nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
53+
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
54+
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
55+
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
56+
57+
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
58+
static_cast<void*>(weight_data),
59+
Y_t->memory_size() / sizeof(float)};
60+
61+
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
62+
auto* layer = TRT_ENGINE_ADD_LAYER(
63+
engine_, Convolution, *const_cast<nvinfer1::ITensor*>(X), n_output,
64+
nv_ksize, weight.get(), bias.get());
65+
PADDLE_ENFORCE(layer != nullptr);
66+
layer->setStride(nv_strides);
67+
layer->setPadding(nv_paddings);
68+
layer->setDilation(nv_dilations);
69+
layer->setNbGroups(groups);
70+
71+
auto output_name = op_desc.Output("Output").front();
72+
engine_->SetITensor(output_name, layer->getOutput(0));
73+
if (test_mode) {
74+
engine_->DeclareOutput(output_name);
75+
}
2876
}
2977
};
3078

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
TEST(conv2d_op, test) {
24+
std::unordered_set<std::string> parameters({"conv2d-Y"});
25+
framework::Scope scope;
26+
TRTConvertValidation validator(2, parameters, scope, 1 << 15);
27+
28+
validator.DeclInputVar("conv2d-X", nvinfer1::Dims4(1, 2, 5, 5));
29+
validator.DeclParamVar("conv2d-Y", nvinfer1::Dims4(3, 2, 3, 3));
30+
validator.DeclOutputVar("conv2d-Out", nvinfer1::Dims4(1, 3, 5, 5));
31+
32+
// Prepare Op description
33+
framework::OpDesc desc;
34+
desc.SetType("conv2d");
35+
desc.SetInput("Input", {"conv2d-X"});
36+
desc.SetInput("Filter", {"conv2d-Y"});
37+
desc.SetOutput("Output", {"conv2d-Out"});
38+
39+
const std::vector<int> strides({1, 1});
40+
const std::vector<int> paddings({1, 1});
41+
const std::vector<int> dilations({1, 1});
42+
const int groups = 1;
43+
44+
desc.SetAttr("strides", strides);
45+
desc.SetAttr("paddings", paddings);
46+
desc.SetAttr("dilations", dilations);
47+
desc.SetAttr("groups", groups);
48+
49+
validator.SetOp(*desc.Proto());
50+
51+
validator.Execute(1);
52+
}
53+
54+
} // namespace tensorrt
55+
} // namespace inference
56+
} // namespace paddle
57+
USE_OP(conv2d);

0 commit comments

Comments
 (0)