Skip to content

Commit e17ce37

Browse files
authored
Merge pull request #16388 from phlrain/pick_squeeze
Merge pull request #16348 from phlrain/fix_squeeze_check
2 parents 929fb85 + f7d979f commit e17ce37

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

paddle/fluid/operators/squeeze_op.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
4040
"tensor's rank.");
4141
}
4242

43-
auto out_dims = GetOutputShape(axes, x_dims);
43+
auto out_dims = GetOutputShape(axes, x_dims, false);
4444
ctx->SetOutputDim("Out", out_dims);
4545
if (x_dims[0] == out_dims[0]) {
4646
// Only pass LoD when the first dimension of output and Input(X)
@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
5050
}
5151

5252
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
53-
const framework::DDim &in_dims) {
53+
const framework::DDim &in_dims,
54+
bool is_runtime) {
5455
size_t num_squeeze_dims = squeeze_dims.size();
5556
int cnt_squeezed_dims = 0;
5657
bool should_squeeze[9] = {false};
@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
7172
// Check current index, the upper limit has beed checked in line 36.
7273
PADDLE_ENFORCE(current >= 0,
7374
"Invalid axis, the negative axis is out of range.");
74-
PADDLE_ENFORCE(in_dims[current] == 1,
75-
"Invalid axis index, the axis that will be squeezed "
76-
"should be equal to 1.");
75+
76+
if (is_runtime) {
77+
PADDLE_ENFORCE(in_dims[current] == 1,
78+
"Invalid axis index, the axis that will be squeezed "
79+
"should be equal to 1.");
80+
}
7781

7882
if (!(should_squeeze[current])) {
7983
++cnt_squeezed_dims;
@@ -103,7 +107,7 @@ class SqueezeOp : public framework::OperatorBase {
103107
const platform::Place &place) const override {
104108
auto &axes = Attr<std::vector<int>>("axes");
105109
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
106-
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
110+
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
107111

108112
framework::AttributeMap attrs;
109113
attrs["shape"] = framework::vectorize2int(out_dims);
@@ -223,7 +227,7 @@ class Squeeze2Op : public framework::OperatorBase {
223227
const platform::Place &place) const override {
224228
auto &axes = Attr<std::vector<int>>("axes");
225229
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
226-
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
230+
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
227231

228232
framework::AttributeMap attrs;
229233
attrs["shape"] = framework::vectorize2int(out_dims);

0 commit comments

Comments
 (0)