@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
40
40
" tensor's rank." );
41
41
}
42
42
43
- auto out_dims = GetOutputShape (axes, x_dims);
43
+ auto out_dims = GetOutputShape (axes, x_dims, false );
44
44
ctx->SetOutputDim (" Out" , out_dims);
45
45
if (x_dims[0 ] == out_dims[0 ]) {
46
46
// Only pass LoD when the first dimension of output and Input(X)
@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
50
50
}
51
51
52
52
static framework::DDim GetOutputShape (const std::vector<int > squeeze_dims,
53
- const framework::DDim &in_dims) {
53
+ const framework::DDim &in_dims,
54
+ bool is_runtime) {
54
55
size_t num_squeeze_dims = squeeze_dims.size ();
55
56
int cnt_squeezed_dims = 0 ;
56
57
bool should_squeeze[9 ] = {false };
@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
71
72
// Check current index, the upper limit has beed checked in line 36.
72
73
PADDLE_ENFORCE (current >= 0 ,
73
74
" Invalid axis, the negative axis is out of range." );
74
- PADDLE_ENFORCE (in_dims[current] == 1 ,
75
- " Invalid axis index, the axis that will be squeezed "
76
- " should be equal to 1." );
75
+
76
+ if (is_runtime) {
77
+ PADDLE_ENFORCE (in_dims[current] == 1 ,
78
+ " Invalid axis index, the axis that will be squeezed "
79
+ " should be equal to 1." );
80
+ }
77
81
78
82
if (!(should_squeeze[current])) {
79
83
++cnt_squeezed_dims;
@@ -103,7 +107,7 @@ class SqueezeOp : public framework::OperatorBase {
103
107
const platform::Place &place) const override {
104
108
auto &axes = Attr<std::vector<int >>(" axes" );
105
109
auto x_dims = scope.FindVar (Input (" X" ))->Get <framework::LoDTensor>().dims ();
106
- auto out_dims = SqueezeOpInferShape::GetOutputShape (axes, x_dims);
110
+ auto out_dims = SqueezeOpInferShape::GetOutputShape (axes, x_dims, true );
107
111
108
112
framework::AttributeMap attrs;
109
113
attrs[" shape" ] = framework::vectorize2int (out_dims);
@@ -223,7 +227,7 @@ class Squeeze2Op : public framework::OperatorBase {
223
227
const platform::Place &place) const override {
224
228
auto &axes = Attr<std::vector<int >>(" axes" );
225
229
auto x_dims = scope.FindVar (Input (" X" ))->Get <framework::LoDTensor>().dims ();
226
- auto out_dims = Squeeze2OpInferShape::GetOutputShape (axes, x_dims);
230
+ auto out_dims = Squeeze2OpInferShape::GetOutputShape (axes, x_dims, true );
227
231
228
232
framework::AttributeMap attrs;
229
233
attrs[" shape" ] = framework::vectorize2int (out_dims);
0 commit comments