Skip to content

Commit 534beb5

Browse files
committed
fix for itnerpolate. test=release/1.4
1 parent 70a967d commit 534beb5

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

paddle/fluid/operators/interpolate_op.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
4040
int out_h = ctx->Attrs().Get<int>("out_h");
4141
int out_w = ctx->Attrs().Get<int>("out_w");
4242
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
43+
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
44+
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
4345

4446
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
4547
auto out_size_dim = ctx->GetInputDim("OutSize");
@@ -50,12 +52,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
5052
return;
5153
}
5254

53-
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
54-
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
55-
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
56-
} else {
57-
ctx->SetOutputDim("Out", dim_x);
58-
}
55+
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
56+
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
5957
}
6058

6159
protected:

0 commit comments

Comments
 (0)