@@ -49,7 +49,8 @@ struct SimpleOpTypeSetTeller : public Teller {
49
49
#endif
50
50
}
51
51
52
- bool operator ()(const std::string& op_type, const framework::OpDesc& desc,
52
+ bool operator ()(const std::string& op_type,
53
+ const framework::OpDesc& desc,
53
54
bool use_no_calib_int8) override {
54
55
if (use_no_calib_int8) {
55
56
return int8_teller_set.count (op_type);
@@ -111,6 +112,7 @@ struct SimpleOpTypeSetTeller : public Teller {
111
112
" mish" ,
112
113
" nearest_interp_v2" ,
113
114
" bilinear_interp_v2" ,
115
+ " cast" ,
114
116
" pool3d" ,
115
117
" deformable_conv" ,
116
118
" relu6" ,
@@ -175,6 +177,7 @@ struct SimpleOpTypeSetTeller : public Teller {
175
177
" mish" ,
176
178
" bilinear_interp_v2" ,
177
179
" nearest_interp_v2" ,
180
+ " cast" ,
178
181
" pool3d" ,
179
182
" deformable_conv" ,
180
183
" relu6" ,
@@ -191,7 +194,8 @@ struct SimpleOpTypeSetTeller : public Teller {
191
194
" multiclass_nms3" };
192
195
};
193
196
194
- bool OpTeller::Tell (const framework::ir::Node* node, bool use_no_calib_int8,
197
+ bool OpTeller::Tell (const framework::ir::Node* node,
198
+ bool use_no_calib_int8,
195
199
bool with_dynamic_shape) {
196
200
const std::string op_type = node->Op ()->Type ();
197
201
const framework::OpDesc desc = *node->Op ();
@@ -706,8 +710,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
706
710
}
707
711
708
712
if (op_type == " nearest_interp" ) {
709
- std::vector<std::string> attrs{" interp_method " , " align_corners " , " scale " ,
710
- " out_h" , " out_w" };
713
+ std::vector<std::string> attrs{
714
+ " interp_method " , " align_corners " , " scale " , " out_h" , " out_w" };
711
715
for (auto const attr : attrs) {
712
716
if (!desc.HasAttr (attr)) return false ;
713
717
}
@@ -747,9 +751,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
747
751
}
748
752
749
753
if (op_type == " nearest_interp_v2" ) {
750
- std::vector<std::string> attrs{" data_layout" , " interp_method" ,
751
- " align_corners" , " scale" ,
752
- " out_h" , " out_w" };
754
+ std::vector<std::string> attrs{" data_layout" ,
755
+ " interp_method" ,
756
+ " align_corners" ,
757
+ " scale" ,
758
+ " out_h" ,
759
+ " out_w" };
753
760
for (auto const attr : attrs) {
754
761
if (!desc.HasAttr (attr)) return false ;
755
762
}
@@ -775,9 +782,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
775
782
}
776
783
777
784
if (op_type == " bilinear_interp_v2" ) {
778
- std::vector<std::string> attrs{" data_layout" , " interp_method" ,
779
- " align_corners" , " scale" ,
780
- " out_h" , " out_w" };
785
+ std::vector<std::string> attrs{" data_layout" ,
786
+ " interp_method" ,
787
+ " align_corners" ,
788
+ " scale" ,
789
+ " out_h" ,
790
+ " out_w" };
781
791
for (auto const attr : attrs) {
782
792
if (!desc.HasAttr (attr)) {
783
793
VLOG (3 ) << " The op_type " << op_type << " doesn't have the attr "
@@ -882,8 +892,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
882
892
}
883
893
884
894
if (op_type == " batch_norm" ) {
885
- const std::vector<std::string> bn_inputs = {" X " , " Bias " , " Mean " , " Scale " ,
886
- " Variance" };
895
+ const std::vector<std::string> bn_inputs = {
896
+ " X " , " Bias " , " Mean " , " Scale " , " Variance" };
887
897
for (unsigned int i = 0 ; i < bn_inputs.size (); i++) {
888
898
if (desc.Input (bn_inputs[i]).size () != 1 ) {
889
899
VLOG (3 ) << " Invalid " << bn_inputs[i]
@@ -1458,8 +1468,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
1458
1468
" the roi_align will change the batch size." ;
1459
1469
return false ;
1460
1470
}
1461
- std::vector<std::string> attrs{" pooled_height" , " pooled_width" ,
1462
- " spatial_scale" , " sampling_ratio" ,
1471
+ std::vector<std::string> attrs{" pooled_height" ,
1472
+ " pooled_width" ,
1473
+ " spatial_scale" ,
1474
+ " sampling_ratio" ,
1463
1475
" aligned" };
1464
1476
for (auto const attr : attrs) {
1465
1477
if (!desc.HasAttr (attr)) return false ;
@@ -1641,10 +1653,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
1641
1653
auto x_var_name = desc.Input (" X" )[0 ];
1642
1654
auto * x_var_desc = block->FindVar (x_var_name);
1643
1655
const auto x_shape = x_var_desc->GetShape ();
1644
- int input_num = std::accumulate (x_shape. begin () + 1 , x_shape. end (), 1 ,
1645
- std::multiplies<int >());
1646
- int shape_num = std::accumulate (shape. begin () + 1 , shape. end (), 1 ,
1647
- std::multiplies<int >());
1656
+ int input_num = std::accumulate (
1657
+ x_shape. begin () + 1 , x_shape. end (), 1 , std::multiplies<int >());
1658
+ int shape_num = std::accumulate (
1659
+ shape. begin () + 1 , shape. end (), 1 , std::multiplies<int >());
1648
1660
if (input_num == shape_num) {
1649
1661
return true ;
1650
1662
}
@@ -1751,6 +1763,36 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
1751
1763
}
1752
1764
#endif
1753
1765
1766
+ if (op_type == " cast" ) {
1767
+ // trt 6015 result in Windows ppyolo_mbv3 TRT fp32 diff
1768
+ #if !IS_TRT_VERSION_GE(7000)
1769
+ return false ;
1770
+ #endif
1771
+ if (!(desc.HasAttr (" in_dtype" ) && desc.HasAttr (" out_dtype" ))) {
1772
+ VLOG (3 ) << " the " << op_type
1773
+ << " does not have attr (in_dtype or "
1774
+ " out_dtype)" ;
1775
+ return false ;
1776
+ }
1777
+ int in_dtype = BOOST_GET_CONST (int , desc.GetAttr (" in_dtype" ));
1778
+ int out_dtype = BOOST_GET_CONST (int , desc.GetAttr (" out_dtype" ));
1779
+ if ((in_dtype == 4 || in_dtype == 5 ) && out_dtype == 4 ) {
1780
+ VLOG (3 ) << " unsupport data type conversion" ;
1781
+ return false ;
1782
+ }
1783
+ if (in_dtype == 0 ) {
1784
+ VLOG (3 ) << " do not support input data type as bool now" ;
1785
+ return false ;
1786
+ }
1787
+ if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 ) &&
1788
+ (out_dtype == 5 || out_dtype == 4 || out_dtype == 2 ))) {
1789
+ VLOG (3 )
1790
+ << " only valid conversions are: "
1791
+ " (kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)" ;
1792
+ return false ;
1793
+ }
1794
+ }
1795
+
1754
1796
if (op_type == " conv3d" || op_type == " conv3d_transpose" ) {
1755
1797
if (desc.HasAttr (" padding_algorithm" )) {
1756
1798
std::string padding_algorithm =
0 commit comments