
Commit 7cdce09

[cherry pick] add cast trt convert (#44837)
* add cast trt convert
* skip cast trt convert when input dtype is bool
* code format
* fix bug
* update unittest
* fix bug
1 parent 627e5bd commit 7cdce09


5 files changed (+317, -49 lines)


paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 0 deletions
@@ -1793,6 +1793,7 @@ USE_TRT_CONVERTER(multiclass_nms3);
 USE_TRT_CONVERTER(nearest_interp);
 USE_TRT_CONVERTER(nearest_interp_v2);
 USE_TRT_CONVERTER(bilinear_interp_v2);
+USE_TRT_CONVERTER(cast);
 USE_TRT_CONVERTER(reshape);
 USE_TRT_CONVERTER(reduce_sum);
 USE_TRT_CONVERTER(gather_nd);
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 64 additions & 31 deletions
@@ -1,33 +1,66 @@
 # Add TRT tests
-nv_library(tensorrt_converter
-  SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
-  pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
-  shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
-  emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
-  gather_op.cc
-  bilinear_interp_v2_op.cc
-  anchor_generator_op.cc
-  yolo_box_op.cc
-  roi_align_op.cc
-  affine_channel_op.cc
-  multiclass_nms_op.cc
-  multiclass_nms3_op.cc
-  nearest_interp_op.cc
-  reshape_op.cc
-  reduce_op.cc
-  gather_nd_op.cc
-  tile_op.cc
-  conv3d_op.cc
-  mish_op.cc
-  nearest_interp_v2_op.cc
-  pool3d_op.cc
-  deformable_conv_op.cc
-  preln_emb_eltwise_layernorm.cc
-  strided_slice_op.cc
-  preln_skip_layernorm.cc
-  roll_op.cc
-  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
+nv_library(
+  tensorrt_converter
+  SRCS matmul_op.cc
+       conv2d_op.cc
+       fc_op.cc
+       pool2d_op.cc
+       elementwise_op.cc
+       batch_norm_op.cc
+       activation_op.cc
+       softmax_op.cc
+       concat_op.cc
+       dropout_op.cc
+       group_norm_op.cc
+       pad_op.cc
+       split_op.cc
+       prelu_op.cc
+       leaky_relu_op.cc
+       gelu_op.cc
+       layer_norm_op.cc
+       multihead_matmul_op.cc
+       shuffle_channel_op.cc
+       swish_op.cc
+       instance_norm_op.cc
+       stack_op.cc
+       transpose_op.cc
+       flatten_op.cc
+       flatten_contiguous_range_op.cc
+       emb_eltwise_layernorm.cc
+       skip_layernorm.cc
+       scale_op.cc
+       slice_op.cc
+       hard_sigmoid_op.cc
+       hard_swish_op.cc
+       clip_op.cc
+       gather_op.cc
+       bilinear_interp_v2_op.cc
+       cast_op.cc
+       anchor_generator_op.cc
+       yolo_box_op.cc
+       roi_align_op.cc
+       affine_channel_op.cc
+       multiclass_nms_op.cc
+       multiclass_nms3_op.cc
+       nearest_interp_op.cc
+       reshape_op.cc
+       reduce_op.cc
+       gather_nd_op.cc
+       tile_op.cc
+       conv3d_op.cc
+       mish_op.cc
+       nearest_interp_v2_op.cc
+       pool3d_op.cc
+       deformable_conv_op.cc
+       preln_emb_eltwise_layernorm.cc
+       strided_slice_op.cc
+       preln_skip_layernorm.cc
+       roll_op.cc
+  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
+       op_registry)

-nv_test(test_op_converter SRCS test_op_converter.cc DEPS
-  paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
+nv_test(
+  test_op_converter
+  SRCS test_op_converter.cc
+  DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine
+       tensorrt_converter)
paddle/fluid/inference/tensorrt/convert/cast_op.cc (new file)

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class CastOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(3) << "convert a cast op to tensorrt";
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto out_dtype = BOOST_GET_CONST(int, op_desc.GetAttr("out_dtype"));
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
+
+    switch (out_dtype) {
+      case 2:  // INT32 = 2
+        layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
+        break;
+      case 4:  // FP16 = 4
+        layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
+        break;
+      case 5:  // FP32 = 5
+        layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
+        break;
+      default:
+        LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
+                   << ") to a nvinfer DataType";
+        break;
+    }
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "cast", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(cast, CastOpConverter);
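The converter body is small because this TensorRT version has no dedicated cast layer: the idiom is an identity layer whose output tensor is pinned to the target type, and TensorRT materializes the conversion at engine build time. A minimal sketch of the same idiom against the raw TensorRT API, assuming a hypothetical `network` (nvinfer1::INetworkDefinition*) and `input` (nvinfer1::ITensor*):

// Cast `input` to FP16, mirroring TRT_ENGINE_ADD_LAYER(engine_, Identity,
// *input) followed by setType in the converter above.
nvinfer1::IIdentityLayer* cast_layer = network->addIdentity(*input);
cast_layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);

Note that the out_dtype codes (2, 4, 5) are Paddle VarType values, not TensorRT enums; the op teller below rejects anything outside this mapping.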

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 60 additions & 18 deletions
@@ -49,7 +49,8 @@ struct SimpleOpTypeSetTeller : public Teller {
 #endif
   }

-  bool operator()(const std::string& op_type, const framework::OpDesc& desc,
+  bool operator()(const std::string& op_type,
+                  const framework::OpDesc& desc,
                   bool use_no_calib_int8) override {
     if (use_no_calib_int8) {
       return int8_teller_set.count(op_type);
@@ -111,6 +112,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "mish",
       "nearest_interp_v2",
       "bilinear_interp_v2",
+      "cast",
       "pool3d",
       "deformable_conv",
       "relu6",
@@ -175,6 +177,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "mish",
       "bilinear_interp_v2",
       "nearest_interp_v2",
+      "cast",
       "pool3d",
       "deformable_conv",
       "relu6",
@@ -191,7 +194,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "multiclass_nms3"};
 };

-bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
+bool OpTeller::Tell(const framework::ir::Node* node,
+                    bool use_no_calib_int8,
                     bool with_dynamic_shape) {
   const std::string op_type = node->Op()->Type();
   const framework::OpDesc desc = *node->Op();
@@ -706,8 +710,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   }

   if (op_type == "nearest_interp") {
-    std::vector<std::string> attrs{"interp_method", "align_corners", "scale",
-                                   "out_h", "out_w"};
+    std::vector<std::string> attrs{
+        "interp_method", "align_corners", "scale", "out_h", "out_w"};
     for (auto const attr : attrs) {
       if (!desc.HasAttr(attr)) return false;
     }
@@ -747,9 +751,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   }

   if (op_type == "nearest_interp_v2") {
-    std::vector<std::string> attrs{"data_layout", "interp_method",
-                                   "align_corners", "scale",
-                                   "out_h", "out_w"};
+    std::vector<std::string> attrs{"data_layout",
+                                   "interp_method",
+                                   "align_corners",
+                                   "scale",
+                                   "out_h",
+                                   "out_w"};
     for (auto const attr : attrs) {
       if (!desc.HasAttr(attr)) return false;
     }
@@ -775,9 +782,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   }

   if (op_type == "bilinear_interp_v2") {
-    std::vector<std::string> attrs{"data_layout", "interp_method",
-                                   "align_corners", "scale",
-                                   "out_h", "out_w"};
+    std::vector<std::string> attrs{"data_layout",
+                                   "interp_method",
+                                   "align_corners",
+                                   "scale",
+                                   "out_h",
+                                   "out_w"};
     for (auto const attr : attrs) {
       if (!desc.HasAttr(attr)) {
         VLOG(3) << "The op_type " << op_type << " doesn't have the attr "
@@ -882,8 +892,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   }

   if (op_type == "batch_norm") {
-    const std::vector<std::string> bn_inputs = {"X", "Bias", "Mean", "Scale",
-                                                "Variance"};
+    const std::vector<std::string> bn_inputs = {
+        "X", "Bias", "Mean", "Scale", "Variance"};
     for (unsigned int i = 0; i < bn_inputs.size(); i++) {
       if (desc.Input(bn_inputs[i]).size() != 1) {
         VLOG(3) << "Invalid " << bn_inputs[i]
@@ -1458,8 +1468,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
               "the roi_align will change the batch size.";
       return false;
     }
-    std::vector<std::string> attrs{"pooled_height", "pooled_width",
-                                   "spatial_scale", "sampling_ratio",
+    std::vector<std::string> attrs{"pooled_height",
+                                   "pooled_width",
+                                   "spatial_scale",
+                                   "sampling_ratio",
                                    "aligned"};
     for (auto const attr : attrs) {
       if (!desc.HasAttr(attr)) return false;
@@ -1641,10 +1653,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     auto x_var_name = desc.Input("X")[0];
     auto* x_var_desc = block->FindVar(x_var_name);
     const auto x_shape = x_var_desc->GetShape();
-    int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1,
-                                    std::multiplies<int>());
-    int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1,
-                                    std::multiplies<int>());
+    int input_num = std::accumulate(
+        x_shape.begin() + 1, x_shape.end(), 1, std::multiplies<int>());
+    int shape_num = std::accumulate(
+        shape.begin() + 1, shape.end(), 1, std::multiplies<int>());
     if (input_num == shape_num) {
       return true;
     }
@@ -1751,6 +1763,36 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
   }
 #endif

+  if (op_type == "cast") {
+// trt 6015 result in Windows ppyolo_mbv3 TRT fp32 diff
+#if !IS_TRT_VERSION_GE(7000)
+    return false;
+#endif
+    if (!(desc.HasAttr("in_dtype") && desc.HasAttr("out_dtype"))) {
+      VLOG(3) << "the " << op_type
+              << " does not have attr (in_dtype or "
+                 "out_dtype)";
+      return false;
+    }
+    int in_dtype = BOOST_GET_CONST(int, desc.GetAttr("in_dtype"));
+    int out_dtype = BOOST_GET_CONST(int, desc.GetAttr("out_dtype"));
+    if ((in_dtype == 4 || in_dtype == 5) && out_dtype == 4) {
+      VLOG(3) << "unsupport data type conversion";
+      return false;
+    }
+    if (in_dtype == 0) {
+      VLOG(3) << "do not support input data type as bool now";
+      return false;
+    }
+    if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) &&
+          (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
+      VLOG(3)
+          << "only valid conversions are: "
+             "(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)";
+      return false;
+    }
+  }
+
   if (op_type == "conv3d" || op_type == "conv3d_transpose") {
     if (desc.HasAttr("padding_algorithm")) {
       std::string padding_algorithm =
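For reference, the integer codes in the cast guard above are Paddle's framework::proto::VarType values, not TensorRT enums. A sketch of the relevant values (from framework.proto; not part of this commit, listed only to decode the comparisons):

// Paddle VarType codes checked by the cast teller above.
enum PaddleVarTypeCode {
  kBOOL = 0,   // rejected: bool input is skipped for now
  kINT32 = 2,  // accepted; the converter maps it to nvinfer1::DataType::kINT32
  kFP16 = 4,   // accepted as input, but FP16/FP32 -> FP16 is rejected
  kFP32 = 5,   // accepted; maps to nvinfer1::DataType::kFLOAT
};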
