@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
40
40
" tensor's rank." );
41
41
}
42
42
43
- auto out_dims = GetOutputShape (axes, x_dims, ctx );
43
+ auto out_dims = GetOutputShape (axes, x_dims, false );
44
44
ctx->SetOutputDim (" Out" , out_dims);
45
45
if (x_dims[0 ] == out_dims[0 ]) {
46
46
// Only pass LoD when the first dimension of output and Input(X)
@@ -51,7 +51,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
51
51
52
52
static framework::DDim GetOutputShape (const std::vector<int > squeeze_dims,
53
53
const framework::DDim &in_dims,
54
- framework::InferShapeContext *ctx ) {
54
+ bool is_runtime ) {
55
55
size_t num_squeeze_dims = squeeze_dims.size ();
56
56
int cnt_squeezed_dims = 0 ;
57
57
bool should_squeeze[9 ] = {false };
@@ -73,7 +73,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
73
73
PADDLE_ENFORCE (current >= 0 ,
74
74
" Invalid axis, the negative axis is out of range." );
75
75
76
- if (ctx-> IsRuntime () ) {
76
+ if (is_runtime ) {
77
77
PADDLE_ENFORCE (in_dims[current] == 1 ,
78
78
" Invalid axis index, the axis that will be squeezed "
79
79
" should be equal to 1." );
@@ -108,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
108
108
const platform::Place &place) const override {
109
109
auto &axes = Attr<std::vector<int >>(" axes" );
110
110
auto x_dims = scope.FindVar (Input (" X" ))->Get <framework::LoDTensor>().dims ();
111
- auto out_dims = SqueezeOpInferShape::GetOutputShape (axes, x_dims);
111
+ auto out_dims = SqueezeOpInferShape::GetOutputShape (axes, x_dims, true );
112
112
113
113
framework::AttributeMap attrs;
114
114
attrs[" shape" ] = framework::vectorize2int (out_dims);
@@ -228,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
228
228
const platform::Place &place) const override {
229
229
auto &axes = Attr<std::vector<int >>(" axes" );
230
230
auto x_dims = scope.FindVar (Input (" X" ))->Get <framework::LoDTensor>().dims ();
231
- auto out_dims = Squeeze2OpInferShape::GetOutputShape (axes, x_dims);
231
+ auto out_dims = Squeeze2OpInferShape::GetOutputShape (axes, x_dims, true );
232
232
233
233
framework::AttributeMap attrs;
234
234
attrs[" shape" ] = framework::vectorize2int (out_dims);
0 commit comments