Skip to content

Commit fc88437

Browse files
committed
modify infer shape in pad_op.cc, pad_constant_like_op.cc. No need in psroi_pool_op.cc, crop_op.cc
1 parent 3063449 commit fc88437

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

paddle/fluid/operators/pad_constant_like_op.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/pad_constant_like_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -38,8 +39,13 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
3839
"The dimention of X and Y should be the same.");
3940

4041
for (int i = 0; i < x_dim.size(); ++i) {
41-
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
42+
if ((!ctx->IsRuntime()) && ((x_dim[i] == -1) || (y_dim[i] == -1))) {
43+
continue;
44+
} else {
45+
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
46+
}
4247
}
48+
4349
ctx->SetOutputDim("Out", x_dim);
4450
ctx->ShareLoD("X", /*->*/ "Out");
4551
}
@@ -162,7 +168,11 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
162168
ctx->ShareLoD("Y", /*->*/ y_grad_name);
163169

164170
for (int i = 0; i < y_dim.size(); ++i) {
165-
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
171+
if ((!ctx->IsRuntime()) && ((dout_dim[i] == -1) || (y_dim[i] == -1))) {
172+
continue;
173+
} else {
174+
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
175+
}
166176
}
167177
}
168178
}

paddle/fluid/operators/pad_op.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ class PadOp : public framework::OperatorWithKernel {
3636
"of input tensor.");
3737
std::vector<int64_t> out_dims(x_dim.size());
3838
for (int i = 0; i < x_dim.size(); ++i) {
39-
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
39+
if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
40+
out_dims[i] = -1;
41+
} else {
42+
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
43+
}
4044
}
4145
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
4246
if (out_dims[0] == x_dim[0]) {
@@ -100,18 +104,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
100104
using framework::OperatorWithKernel::OperatorWithKernel;
101105

102106
void InferShape(framework::InferShapeContext* ctx) const override {
103-
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
104-
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
105-
for (int i = 0; i < dout_dims.size(); ++i) {
106-
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
107-
}
108-
109107
auto x_grad_name = framework::GradVarName("X");
110108
if (ctx->HasOutput(x_grad_name)) {
111109
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
112110
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
113111
for (int i = 0; i < dout_dims.size(); ++i) {
114-
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
112+
if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
113+
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
114+
}
115115
}
116116
ctx->SetOutputDim(x_grad_name, dout_dims);
117117
}

0 commit comments

Comments
 (0)