
Commit 247002e

commit (#44887)
1 parent 24b3bbd commit 247002e

8 files changed: +594 -43 lines changed


paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 0 deletions
@@ -1808,6 +1808,8 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(roll)
 USE_TRT_CONVERTER(strided_slice)
+USE_TRT_CONVERTER(squeeze2)
+USE_TRT_CONVERTER(unsqueeze2)
 #endif

 namespace paddle_infer {

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ nv_library(
 strided_slice_op.cc
 preln_skip_layernorm.cc
 roll_op.cc
+squeeze2_op.cc
+unsqueeze2_op.cc
 DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
 op_registry)

paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class Squeeze2OpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid squeeze2 op to tensorrt shuffle layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    auto input_dims = input->getDimensions();
    auto output_name = op_desc.Output("Out")[0];

    // Get Attrs
    std::vector<int> axes =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
    PADDLE_ENFORCE_GT(
        axes.size(),
        0,
        platform::errors::InvalidArgument(
            "Attr(axes).size should be > 0 in squeeze2 op in TensorRT, "
            "but received axes.size() = %d.",
            axes.size()));

    std::vector<bool> should_squeeze(input_dims.nbDims, false);
    for (size_t i = 0; i < axes.size(); i++) {
      if (engine_->with_dynamic_shape()) {
        axes[i] += (axes[i] < 0) ? input_dims.nbDims : 0;
      } else {
        axes[i] += (axes[i] < 0) ? input_dims.nbDims : -1;
      }
      should_squeeze[axes[i]] = true;
    }

    nvinfer1::Dims trt_out_dims;
    trt_out_dims.nbDims = 0;
    std::vector<int32_t> gather_indices;
    for (size_t i = 0; i < should_squeeze.size(); i++) {
      if (should_squeeze[i]) continue;
      gather_indices.push_back(i);
      // for static shape
      trt_out_dims.d[trt_out_dims.nbDims] = input_dims.d[i];
      trt_out_dims.nbDims++;
    }

    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
    if (engine_->with_dynamic_shape()) {
      auto* shape_tensor = Shape(input);
      auto* real_shape_tensor = Gather(shape_tensor, gather_indices);
      layer->setInput(1, *real_shape_tensor);
    } else {
      layer->setReshapeDimensions(trt_out_dims);
    }
    RreplenishLayerAndOutput(layer, "squeeze2", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(squeeze2, Squeeze2OpConverter);
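The axis bookkeeping is the subtle part of this converter: in static shape mode the TensorRT dims exclude the batch dimension, so non-negative axes shift left by one, while negative axes wrap within the TRT rank. The sketch below reproduces only that index math on plain std::vector; it is a minimal illustration, not Paddle or TensorRT API, and the helper name SqueezeDimsStatic is made up here.

// Minimal sketch of the static-shape squeeze index math above,
// operating on plain vectors instead of nvinfer1::Dims.
#include <cassert>
#include <cstddef>
#include <vector>

// dims: TRT dims without the batch dimension; axes: as written in the op.
std::vector<int> SqueezeDimsStatic(std::vector<int> dims, std::vector<int> axes) {
  std::vector<bool> should_squeeze(dims.size(), false);
  for (int& a : axes) {
    // Negative axes wrap within the TRT rank; non-negative axes drop the
    // implicit batch slot, mirroring the converter's "axes[i] += ... : -1".
    a += (a < 0) ? static_cast<int>(dims.size()) : -1;
    should_squeeze[a] = true;
  }
  std::vector<int> out;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (!should_squeeze[i]) out.push_back(dims[i]);
  }
  return out;
}

int main() {
  // Paddle shape (N, 1, 64, 1) is seen by TRT as (1, 64, 1); squeeze axes {1, 3}.
  assert((SqueezeDimsStatic({1, 64, 1}, {1, 3}) == std::vector<int>{64}));
  // A negative axis behaves the same: -1 refers to the trailing 1.
  assert((SqueezeDimsStatic({1, 64, 1}, {-1}) == std::vector<int>{1, 64}));
  return 0;
}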
paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class Unsqueeze2OpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid unsqueeze2 op to tensorrt shuffle layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    auto input_dims = input->getDimensions();
    auto output_name = op_desc.Output("Out")[0];

    // Get Attrs
    std::vector<int> axes =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
    PADDLE_ENFORCE_GT(
        axes.size(),
        0,
        platform::errors::InvalidArgument(
            "Attr(axes).size should be > 0 in unsqueeze2 op in TensorRT, "
            "but received axes.size() = %d.",
            axes.size()));

    std::vector<bool> should_unsqueeze(input_dims.nbDims + axes.size(), false);
    int cur_out_rank = input_dims.nbDims;
    for (size_t i = 0; i < axes.size(); i++) {
      cur_out_rank++;
      if (engine_->with_dynamic_shape()) {
        axes[i] += (axes[i] < 0) ? cur_out_rank : 0;
      } else {
        axes[i] += (axes[i] < 0) ? cur_out_rank : -1;
      }
      // axes[i] is an index into the current output rank (cur_out_rank).
      // Shift the flags in [axes[i], cur_out_rank - 2] one slot to the right,
      // then mark position axes[i] as a newly inserted dimension.
      for (int j = cur_out_rank - 1; j > axes[i]; j--) {
        should_unsqueeze[j] = should_unsqueeze[j - 1];
      }
      if (axes[i] >= cur_out_rank)
        should_unsqueeze[cur_out_rank - 1] = true;
      else
        should_unsqueeze[axes[i]] = true;
    }

    nvinfer1::Dims trt_out_dims;
    trt_out_dims.nbDims = should_unsqueeze.size();
    std::vector<int32_t> gather_indices;
    int in_rank_i = 0;
    for (size_t i = 0; i < should_unsqueeze.size(); i++) {
      if (should_unsqueeze[i]) {
        trt_out_dims.d[i] = 1;
        gather_indices.push_back(input_dims.nbDims);
        continue;
      }
      trt_out_dims.d[i] = input_dims.d[in_rank_i];
      gather_indices.push_back(in_rank_i);
      in_rank_i++;
    }

    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
    if (engine_->with_dynamic_shape()) {
      auto* shape_tensor = Shape(input);
      std::vector<int32_t> all_one(axes.size(), 1);
      auto* all_one_tensor = Add1DConstantLayer(all_one);
      std::vector<nvinfer1::ITensor*> concat_inputs = {shape_tensor,
                                                       all_one_tensor};
      auto* real_shape_tensor = Gather(Concat(concat_inputs), gather_indices);
      layer->setInput(1, *real_shape_tensor);
    } else {
      layer->setReshapeDimensions(trt_out_dims);
    }
    RreplenishLayerAndOutput(layer, "unsqueeze2", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(unsqueeze2, Unsqueeze2OpConverter);
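The dynamic-shape branch of this converter builds the output shape at runtime: it concatenates the input's shape tensor with a vector of ones and gathers from that concatenation, with each new axis pointing at one of the appended ones. The sketch below imitates that gather-index construction on plain vectors; it is an illustration only, assumes axes are already normalized to non-negative, non-conflicting output positions, and the helper name UnsqueezeViaGather is invented here.

// Minimal sketch of the dynamic-shape unsqueeze gather construction above,
// using plain vectors instead of ITensors.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> UnsqueezeViaGather(const std::vector<int>& in_shape,
                                    const std::vector<int>& axes) {
  const std::size_t out_rank = in_shape.size() + axes.size();
  std::vector<bool> is_new_axis(out_rank, false);
  for (int a : axes) is_new_axis[a] = true;

  // Mirror of Concat(shape_tensor, all_one_tensor): [runtime shape, 1, 1, ...].
  std::vector<int> concat = in_shape;
  concat.insert(concat.end(), axes.size(), 1);

  // gather_indices: original dim index for kept dims, index of an appended 1
  // (i.e. in_shape.size()) for newly inserted dims.
  std::vector<int> gather_indices;
  std::size_t in_i = 0;
  for (std::size_t i = 0; i < out_rank; ++i) {
    gather_indices.push_back(is_new_axis[i] ? static_cast<int>(in_shape.size())
                                            : static_cast<int>(in_i++));
  }

  std::vector<int> out;
  for (int idx : gather_indices) out.push_back(concat[idx]);
  return out;
}

int main() {
  // A (3, 64) tensor unsqueezed at axes {0, 3} becomes (1, 3, 64, 1).
  assert((UnsqueezeViaGather({3, 64}, {0, 3}) == std::vector<int>{1, 3, 64, 1}));
  return 0;
}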

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 42 additions & 0 deletions
@@ -114,6 +114,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "bilinear_interp_v2",
       "cast",
       "pool3d",
+      "squeeze2",
+      "unsqueeze2",
       "deformable_conv",
       "relu6",
       "hard_sigmoid",

@@ -179,6 +181,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "nearest_interp_v2",
       "cast",
       "pool3d",
+      "squeeze2",
+      "unsqueeze2",
       "deformable_conv",
       "relu6",
       "hard_sigmoid",
@@ -891,6 +895,44 @@ bool OpTeller::Tell(const framework::ir::Node* node,
     }
   }

+  if (op_type == "squeeze2") {
+    std::vector<int> axes;
+    if (desc.HasAttr("axes")) {
+      axes = BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
+    }
+    if (axes.size() == 0) {
+      VLOG(3) << "The necessary attribute of the squeeze2 operator, axes, is "
+                 "missing.";
+      return false;
+    }
+    if (!with_dynamic_shape) {
+      if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
+        VLOG(3) << "Invalid squeeze axes. Axes containing the batch axis are "
+                   "not supported in static shape.";
+        return false;
+      }
+    }
+  }
+
+  if (op_type == "unsqueeze2") {
+    std::vector<int> axes;
+    if (desc.HasAttr("axes")) {
+      axes = BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
+    }
+    if (axes.size() == 0) {
+      VLOG(3) << "The necessary attribute of the unsqueeze2 operator, axes, is "
+                 "missing.";
+      return false;
+    }
+    if (!with_dynamic_shape) {
+      if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
+        VLOG(3) << "Invalid unsqueeze axes. Axes containing the batch axis are "
+                   "not supported in static shape.";
+        return false;
+      }
+    }
+  }
+
   if (op_type == "batch_norm") {
     const std::vector<std::string> bn_inputs = {
         "X", "Bias", "Mean", "Scale", "Variance"};
