File tree Expand file tree Collapse file tree 2 files changed +10
-1
lines changed
paddle/fluid/inference/tensorrt/convert Expand file tree Collapse file tree 2 files changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
33
33
PADDLE_ENFORCE_EQ (op_desc.Output (" Out" ).size (), 1 );
34
34
auto * input1 = engine_->GetITensor (op_desc.Input (" X" )[0 ]);
35
35
36
+ bool global_pooling = boost::get<bool >(op_desc.GetAttr (" global_pooling" ));
36
37
std::string pool_type =
37
38
boost::get<std::string>(op_desc.GetAttr (" pooling_type" ));
38
39
std::vector<int > ksize =
@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
42
43
std::vector<int > paddings =
43
44
boost::get<std::vector<int >>(op_desc.GetAttr (" paddings" ));
44
45
45
- const nvinfer1::DimsHW nv_ksize (ksize[0 ], ksize[1 ]);
46
+ nvinfer1::DimsHW nv_ksize (ksize[0 ], ksize[1 ]);
47
+ if (global_pooling == true ) {
48
+ nvinfer1::Dims input_shape = input1->getDimensions ();
49
+ int nbDims = input_shape.nbDims ;
50
+ nv_ksize.d [0 ] = input_shape.d [nbDims - 2 ];
51
+ nv_ksize.d [1 ] = input_shape.d [nbDims - 1 ];
52
+ }
46
53
const nvinfer1::DimsHW nv_strides (strides[0 ], strides[1 ]);
47
54
const nvinfer1::DimsHW nv_paddings (paddings[0 ], paddings[1 ]);
48
55
Original file line number Diff line number Diff line change @@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) {
40
40
std::vector<int > strides ({2 , 2 });
41
41
std::vector<int > paddings ({0 , 0 });
42
42
std::string pooling_t = " max" ;
43
+ bool global_pooling = false ;
43
44
44
45
desc.SetAttr (" pooling_type" , pooling_t );
45
46
desc.SetAttr (" ksize" , ksize);
46
47
desc.SetAttr (" strides" , strides);
47
48
desc.SetAttr (" paddings" , paddings);
49
+ desc.SetAttr (" global_pooling" , global_pooling);
48
50
49
51
LOG (INFO) << " set OP" ;
50
52
validator.SetOp (*desc.Proto ());
You can’t perform that action at this time.
0 commit comments