Skip to content

Commit 591f087

Browse files
Merge pull request #16932 from SunGaofeng/infershape14
Infer shape of pad_op and pad_constant_like_op for version 1.4.0
2 parents 923c337 + 07844e3 commit 591f087

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

paddle/fluid/operators/pad_constant_like_op.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/pad_constant_like_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -38,8 +39,16 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
3839
"The dimention of X and Y should be the same.");
3940

4041
for (int i = 0; i < x_dim.size(); ++i) {
41-
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
42+
if ((!ctx->IsRuntime()) && ((x_dim[i] == -1) || (y_dim[i] == -1))) {
43+
continue;
44+
} else {
45+
PADDLE_ENFORCE_GE(
46+
x_dim[i], y_dim[i],
47+
"expected X_dim[i] >= Y_dim[i], but received %d < %d for dim %d",
48+
x_dim[i], y_dim[i], i);
49+
}
4250
}
51+
4352
ctx->SetOutputDim("Out", x_dim);
4453
ctx->ShareLoD("X", /*->*/ "Out");
4554
}
@@ -162,7 +171,14 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
162171
ctx->ShareLoD("Y", /*->*/ y_grad_name);
163172

164173
for (int i = 0; i < y_dim.size(); ++i) {
165-
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
174+
if ((!ctx->IsRuntime()) && ((dout_dim[i] == -1) || (y_dim[i] == -1))) {
175+
continue;
176+
} else {
177+
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i],
178+
"expected Out_dim[i] >= Y_dim[i], but received %d "
179+
"< %d for dim %d",
180+
dout_dim[i], y_dim[i], i);
181+
}
166182
}
167183
}
168184
}

paddle/fluid/operators/pad_op.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,16 @@ class PadOp : public framework::OperatorWithKernel {
3434
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
3535
"Size of paddings should be equal to 2 * dimension size "
3636
"of input tensor.");
37+
for (size_t i = 0; i < paddings.size(); ++i) {
38+
PADDLE_ENFORCE_GE(paddings[i], 0, "paddings should >= 0.");
39+
}
3740
std::vector<int64_t> out_dims(x_dim.size());
3841
for (int i = 0; i < x_dim.size(); ++i) {
39-
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
42+
if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
43+
out_dims[i] = -1;
44+
} else {
45+
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
46+
}
4047
}
4148
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
4249
if (out_dims[0] == x_dim[0]) {
@@ -100,18 +107,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
100107
using framework::OperatorWithKernel::OperatorWithKernel;
101108

102109
void InferShape(framework::InferShapeContext* ctx) const override {
103-
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
104-
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
105-
for (int i = 0; i < dout_dims.size(); ++i) {
106-
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
107-
}
108-
109110
auto x_grad_name = framework::GradVarName("X");
110111
if (ctx->HasOutput(x_grad_name)) {
111112
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
112113
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
113114
for (int i = 0; i < dout_dims.size(); ++i) {
114-
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
115+
if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
116+
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
117+
}
115118
}
116119
ctx->SetOutputDim(x_grad_name, dout_dims);
117120
}

0 commit comments

Comments (0)