@@ -354,6 +354,22 @@ bool OpTeller::Tell(const framework::ir::Node* node,
354
354
}
355
355
}
356
356
#endif
357
+ // In fact, this should include all conv, not only conv2d
358
+ if (op_type == " conv2d" ) {
359
+ auto * block = desc.Block ();
360
+ if (block == nullptr ) {
361
+ VLOG (3 ) << " The block desc is nullptr, we can't continue to analyze. "
362
+ " Developers need to check whether block_desc is passed in "
363
+ " the pass." ;
364
+ return false ;
365
+ }
366
+ auto * filter_var_desc = block->FindVar (desc.Input (" Filter" )[0 ]);
367
+ if (!filter_var_desc->Persistable ()) {
368
+ VLOG (3 ) << " Trt not support filter is a intermediate tensor in "
369
+ " conv2d op." ;
370
+ return false ;
371
+ }
372
+ }
357
373
}
358
374
359
375
if (op_type == " deformable_conv" ) {
@@ -912,6 +928,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
912
928
return false ;
913
929
}
914
930
}
931
+ auto * block = desc.Block ();
932
+ if (block == nullptr ) {
933
+ VLOG (3 ) << " The block desc is nullptr, we can't continue to analyze. "
934
+ " Developers need to check whether block_desc is passed in "
935
+ " the pass." ;
936
+ return false ;
937
+ }
938
+ auto * x_var_desc = block->FindVar (desc.Input (" X" )[0 ]);
939
+ auto dtype = x_var_desc->GetDataType ();
940
+ // At present, forbid int64_t into trt.
941
+ if (dtype == 3 ) {
942
+ return false ;
943
+ }
915
944
}
916
945
917
946
if (op_type == " unsqueeze2" ) {
@@ -931,6 +960,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
931
960
return false ;
932
961
}
933
962
}
963
+ auto * block = desc.Block ();
964
+ if (block == nullptr ) {
965
+ VLOG (3 ) << " The block desc is nullptr, we can't continue to analyze. "
966
+ " Developers need to check whether block_desc is passed in "
967
+ " the pass." ;
968
+ return false ;
969
+ }
970
+ auto * x_var_desc = block->FindVar (desc.Input (" X" )[0 ]);
971
+ auto dtype = x_var_desc->GetDataType ();
972
+ // At present, forbid int64_t into trt.
973
+ if (dtype == 3 ) {
974
+ return false ;
975
+ }
934
976
}
935
977
936
978
if (op_type == " batch_norm" ) {
@@ -1073,6 +1115,11 @@ bool OpTeller::Tell(const framework::ir::Node* node,
1073
1115
auto x_var_name = desc.Input (" X" )[0 ];
1074
1116
auto * x_var_desc = block->FindVar (x_var_name);
1075
1117
const auto x_shape = x_var_desc->GetShape ();
1118
+ auto dtype = x_var_desc->GetDataType ();
1119
+ // At present, only support float32 or float16 into trt.
1120
+ if (!(dtype == 5 || dtype == 4 )) {
1121
+ return false ;
1122
+ }
1076
1123
if (!with_dynamic_shape && x_shape.size () == 1 ) {
1077
1124
VLOG (3 ) << " Scale op does not support 1-dimensional input in tensorrt" ;
1078
1125
return false ;
@@ -1163,6 +1210,20 @@ bool OpTeller::Tell(const framework::ir::Node* node,
1163
1210
return false ;
1164
1211
}
1165
1212
}
1213
+
1214
+ auto * block = desc.Block ();
1215
+ if (block == nullptr ) {
1216
+ VLOG (3 ) << " The block desc is nullptr, we can't continue to analyze. "
1217
+ " Developers need to check whether block_desc is passed in "
1218
+ " the pass." ;
1219
+ return false ;
1220
+ }
1221
+ auto * x_var_desc = block->FindVar (desc.Input (" Input" )[0 ]);
1222
+ auto dtype = x_var_desc->GetDataType ();
1223
+ // At present, forbid int64_t into trt.
1224
+ if (dtype == 3 ) {
1225
+ return false ;
1226
+ }
1166
1227
}
1167
1228
1168
1229
if (op_type == " elementwise_add" || op_type == " elementwise_mul" ||
0 commit comments