Commit f4f4a5e

Merge pull request #16891 from tink2123/cherry-pick
[Cherry pick] modified infer shape
2 parents 9845b22 + b637460

File tree: 5 files changed (+31, -12 lines)

paddle/fluid/operators/affine_channel_op.cc (6 additions & 2 deletions)

@@ -79,9 +79,13 @@ class AffineChannelOp : public framework::OperatorWithKernel {
                                 : x_dims[x_dims.size() - 1]);
 
     PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
-    PADDLE_ENFORCE_EQ(scale_dims[0], C);
     PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
-    PADDLE_ENFORCE_EQ(b_dims[0], C);
+    if (ctx->IsRuntime() || scale_dims[0] > 0) {
+      PADDLE_ENFORCE_EQ(scale_dims[0], C);
+    }
+    if (ctx->IsRuntime() || b_dims[0] > 0) {
+      PADDLE_ENFORCE_EQ(b_dims[0], C);
+    }
 
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", "Out");
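Note: the hunk above shows the guard pattern this PR applies throughout. During compile-time shape inference a dimension may still be unknown and is recorded as -1, so an equality check against it is only enforced at runtime or when the dimension is already positive. Below is a minimal, self-contained C++ sketch of that guard (not Paddle code; ShapeContext and CheckChannelDim are hypothetical stand-ins for the InferShape context and the PADDLE_ENFORCE_EQ check).

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the InferShape context's runtime flag.
struct ShapeContext {
  bool is_runtime;
  bool IsRuntime() const { return is_runtime; }
};

// Enforce `dim == expected` only when `dim` is actually known.
void CheckChannelDim(const ShapeContext& ctx, int64_t dim, int64_t expected) {
  if (ctx.IsRuntime() || dim > 0) {
    assert(dim == expected && "channel dimension mismatch");
  }
  // Otherwise the dimension is still unknown (-1) at compile time:
  // skip the check and let runtime shape inference validate it later.
}

int main() {
  ShapeContext compile_time{false};
  CheckChannelDim(compile_time, -1, 64);  // unknown dim: check is skipped
  ShapeContext runtime{true};
  CheckChannelDim(runtime, 64, 64);       // known dim: check is enforced
  return 0;
}

The same guard appears in detection_map_op.cc and row_conv_op.cc below.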

paddle/fluid/operators/conv_op.cc (8 additions & 3 deletions)

@@ -68,9 +68,14 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
   for (size_t i = 0; i < strides.size(); ++i) {
-    output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                          dilations[i], paddings[i],
-                                          strides[i]));
+    if ((!ctx->IsRuntime()) &&
+        (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
+      output_shape.push_back(-1);
+    } else {
+      output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
+                                            dilations[i], paddings[i],
+                                            strides[i]));
+    }
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
   ctx->ShareLoD("Input", "Output");
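Note: this hunk shows the second half of the pattern: when a spatial input or filter dimension is unknown at compile time, the corresponding output dimension is set to -1 instead of being computed. The same treatment is applied in unpool_op.cc below. A minimal, self-contained C++ sketch follows (not Paddle code; InferConvOutputShape and ConvOutSize are hypothetical names, with ConvOutSize using the standard convolution output-size formula).

#include <cstdint>
#include <iostream>
#include <vector>

// Standard convolution output size for one spatial dimension.
int64_t ConvOutSize(int64_t in, int64_t filter, int64_t dilation,
                    int64_t padding, int64_t stride) {
  const int64_t effective_filter = dilation * (filter - 1) + 1;
  return (in + 2 * padding - effective_filter) / stride + 1;
}

std::vector<int64_t> InferConvOutputShape(
    bool is_runtime, const std::vector<int64_t>& in_dims,
    const std::vector<int64_t>& filter_dims,
    const std::vector<int64_t>& strides, const std::vector<int64_t>& paddings,
    const std::vector<int64_t>& dilations) {
  // Output layout: [batch, out_channels, spatial...].
  std::vector<int64_t> out{in_dims[0], filter_dims[0]};
  for (size_t i = 0; i < strides.size(); ++i) {
    if (!is_runtime && (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
      out.push_back(-1);  // dimension unknown at compile time: propagate -1
    } else {
      out.push_back(ConvOutSize(in_dims[i + 2], filter_dims[i + 2],
                                dilations[i], paddings[i], strides[i]));
    }
  }
  return out;
}

int main() {
  // Compile-time inference on NCHW input with an unknown height (-1).
  auto shape = InferConvOutputShape(false, {8, 3, -1, 32}, {16, 3, 3, 3},
                                    {1, 1}, {1, 1}, {1, 1});
  for (int64_t d : shape) std::cout << d << ' ';  // prints: 8 16 -1 32
  std::cout << '\n';
  return 0;
}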

paddle/fluid/operators/detection_map_op.cc (4 additions & 2 deletions)

@@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(label_dims.size(), 2,
                       "The rank of Input(Label) must be 2, "
                       "the shape is [N, 6].");
-    PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
-                   "The shape of Input(Label) is [N, 6] or [N, 5].");
+    if (ctx->IsRuntime() || label_dims[1] > 0) {
+      PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
+                     "The shape of Input(Label) is [N, 6] or [N, 5].");
+    }
 
     if (ctx->HasInput("PosCount")) {
       PADDLE_ENFORCE(ctx->HasInput("TruePos"),

paddle/fluid/operators/row_conv_op.cc (6 additions & 3 deletions)

@@ -41,9 +41,12 @@ class RowConvOp : public framework::OperatorWithKernel {
     auto filter_dims = ctx->GetInputDim("Filter");
     PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(
-        x_dims[1], filter_dims[1],
-        "The 2nd dimension of Input(X) and Input(Filter) should be same.");
+    if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) {
+      PADDLE_ENFORCE_EQ(
+          x_dims[1], filter_dims[1],
+          "The 2nd dimension of Input(X) and Input(Filter) should be same.");
+    }
+
     ctx->SetOutputDim("Out", x_dims);
     ctx->ShareLoD("X", "Out");
   }

paddle/fluid/operators/unpool_op.cc (7 additions & 2 deletions)

@@ -99,10 +99,15 @@ class UnpoolOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(in_x_dims.size() == 4,
                    "Unpooling intput must be of 4-dimensional.");
     PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
+
     std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
     for (size_t i = 0; i < ksize.size(); ++i) {
-      output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
-                                              paddings[i], strides[i]));
+      if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
+        output_shape.push_back(-1);
+      } else {
+        output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
+                                                paddings[i], strides[i]));
+      }
     }
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   }
