Skip to content

Commit 6a5545a

Browse files
committed
fix squeeze shape check; test=develop
1 parent 190cfd6 commit 6a5545a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

paddle/fluid/operators/squeeze_op.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
4040
"tensor's rank.");
4141
}
4242

43-
auto out_dims = GetOutputShape(axes, x_dims, ctx);
43+
auto out_dims = GetOutputShape(axes, x_dims, false);
4444
ctx->SetOutputDim("Out", out_dims);
4545
if (x_dims[0] == out_dims[0]) {
4646
// Only pass LoD when the first dimension of output and Input(X)
@@ -51,7 +51,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
5151

5252
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
5353
const framework::DDim &in_dims,
54-
framework::InferShapeContext *ctx) {
54+
bool is_runtime) {
5555
size_t num_squeeze_dims = squeeze_dims.size();
5656
int cnt_squeezed_dims = 0;
5757
bool should_squeeze[9] = {false};
@@ -73,7 +73,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
7373
PADDLE_ENFORCE(current >= 0,
7474
"Invalid axis, the negative axis is out of range.");
7575

76-
if (ctx->IsRuntime()) {
76+
if (is_runtime) {
7777
PADDLE_ENFORCE(in_dims[current] == 1,
7878
"Invalid axis index, the axis that will be squeezed "
7979
"should be equal to 1.");
@@ -108,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
108108
const platform::Place &place) const override {
109109
auto &axes = Attr<std::vector<int>>("axes");
110110
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
111-
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
111+
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
112112

113113
framework::AttributeMap attrs;
114114
attrs["shape"] = framework::vectorize2int(out_dims);
@@ -228,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
228228
const platform::Place &place) const override {
229229
auto &axes = Attr<std::vector<int>>("axes");
230230
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
231-
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
231+
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
232232

233233
framework::AttributeMap attrs;
234234
attrs["shape"] = framework::vectorize2int(out_dims);

0 commit comments

Comments
 (0)