@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
40
40
" tensor's rank." );
41
41
}
42
42
43
- auto out_dims = GetOutputShape (axes, x_dims);
43
+ auto out_dims = GetOutputShape (axes, x_dims, ctx );
44
44
ctx->SetOutputDim (" Out" , out_dims);
45
45
if (x_dims[0 ] == out_dims[0 ]) {
46
46
// Only pass LoD when the first dimension of output and Input(X)
@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
50
50
}
51
51
52
52
static framework::DDim GetOutputShape (const std::vector<int > squeeze_dims,
53
- const framework::DDim &in_dims) {
53
+ const framework::DDim &in_dims,
54
+ framework::InferShapeContext *ctx) {
54
55
size_t num_squeeze_dims = squeeze_dims.size ();
55
56
int cnt_squeezed_dims = 0 ;
56
57
bool should_squeeze[9 ] = {false };
@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
71
72
// Check current index, the upper limit has beed checked in line 36.
72
73
PADDLE_ENFORCE (current >= 0 ,
73
74
" Invalid axis, the negative axis is out of range." );
74
- PADDLE_ENFORCE (in_dims[current] == 1 ,
75
- " Invalid axis index, the axis that will be squeezed "
76
- " should be equal to 1." );
75
+
76
+ if (ctx->IsRuntime ()) {
77
+ PADDLE_ENFORCE (in_dims[current] == 1 ,
78
+ " Invalid axis index, the axis that will be squeezed "
79
+ " should be equal to 1." );
80
+ }
77
81
78
82
if (!(should_squeeze[current])) {
79
83
++cnt_squeezed_dims;
0 commit comments