Skip to content

Commit 9a04540

Browse files
[Paddle-TRT] fix conv2d/int64 (#45023)
* fix_conv2d_2_3 * commit * fix_conv2d_2_3 * fix_conv2d_2_3 * fix_conv2d_2_3
1 parent cbab018 commit 9a04540

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,22 @@ bool OpTeller::Tell(const framework::ir::Node* node,
354354
}
355355
}
356356
#endif
357+
// In fact, this should include all conv, not only conv2d
358+
if (op_type == "conv2d") {
359+
auto* block = desc.Block();
360+
if (block == nullptr) {
361+
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
362+
"Developers need to check whether block_desc is passed in "
363+
"the pass.";
364+
return false;
365+
}
366+
auto* filter_var_desc = block->FindVar(desc.Input("Filter")[0]);
367+
if (!filter_var_desc->Persistable()) {
368+
VLOG(3) << "Trt not support filter is a intermediate tensor in "
369+
"conv2d op.";
370+
return false;
371+
}
372+
}
357373
}
358374

359375
if (op_type == "deformable_conv") {
@@ -912,6 +928,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
912928
return false;
913929
}
914930
}
931+
auto* block = desc.Block();
932+
if (block == nullptr) {
933+
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
934+
"Developers need to check whether block_desc is passed in "
935+
"the pass.";
936+
return false;
937+
}
938+
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
939+
auto dtype = x_var_desc->GetDataType();
940+
// At present, forbid int64_t into trt.
941+
if (dtype == 3) {
942+
return false;
943+
}
915944
}
916945

917946
if (op_type == "unsqueeze2") {
@@ -931,6 +960,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
931960
return false;
932961
}
933962
}
963+
auto* block = desc.Block();
964+
if (block == nullptr) {
965+
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
966+
"Developers need to check whether block_desc is passed in "
967+
"the pass.";
968+
return false;
969+
}
970+
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
971+
auto dtype = x_var_desc->GetDataType();
972+
// At present, forbid int64_t into trt.
973+
if (dtype == 3) {
974+
return false;
975+
}
934976
}
935977

936978
if (op_type == "batch_norm") {
@@ -1073,6 +1115,11 @@ bool OpTeller::Tell(const framework::ir::Node* node,
10731115
auto x_var_name = desc.Input("X")[0];
10741116
auto* x_var_desc = block->FindVar(x_var_name);
10751117
const auto x_shape = x_var_desc->GetShape();
1118+
auto dtype = x_var_desc->GetDataType();
1119+
// At present, only support float32 or float16 into trt.
1120+
if (!(dtype == 5 || dtype == 4)) {
1121+
return false;
1122+
}
10761123
if (!with_dynamic_shape && x_shape.size() == 1) {
10771124
VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt";
10781125
return false;
@@ -1163,6 +1210,20 @@ bool OpTeller::Tell(const framework::ir::Node* node,
11631210
return false;
11641211
}
11651212
}
1213+
1214+
auto* block = desc.Block();
1215+
if (block == nullptr) {
1216+
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
1217+
"Developers need to check whether block_desc is passed in "
1218+
"the pass.";
1219+
return false;
1220+
}
1221+
auto* x_var_desc = block->FindVar(desc.Input("Input")[0]);
1222+
auto dtype = x_var_desc->GetDataType();
1223+
// At present, forbid int64_t into trt.
1224+
if (dtype == 3) {
1225+
return false;
1226+
}
11661227
}
11671228

11681229
if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||

0 commit comments

Comments
 (0)