Skip to content

Commit cb6ece6

Browse files
committed
modified infer shape
test=develop
1 parent a5ef6bf commit cb6ece6

File tree

6 files changed

+48
-18
lines changed

6 files changed

+48
-18
lines changed

paddle/fluid/operators/affine_channel_op.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,16 @@ class AffineChannelOp : public framework::OperatorWithKernel {
7979
: x_dims[x_dims.size() - 1]);
8080

8181
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
82-
PADDLE_ENFORCE_EQ(scale_dims[0], C);
8382
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
84-
PADDLE_ENFORCE_EQ(b_dims[0], C);
83+
if (ctx->IsRuntime()) {
84+
PADDLE_ENFORCE_EQ(scale_dims[0], C);
85+
PADDLE_ENFORCE_EQ(b_dims[0], C);
86+
} else {
87+
if (scale_dims[0] > 0 && b_dims[0] > 0) {
88+
PADDLE_ENFORCE_EQ(scale_dims[0], C);
89+
PADDLE_ENFORCE_EQ(b_dims[0], C);
90+
}
91+
}
8592

8693
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
8794
ctx->ShareLoD("X", "Out");

paddle/fluid/operators/conv_op.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,14 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
6868

6969
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
7070
for (size_t i = 0; i < strides.size(); ++i) {
71-
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
72-
dilations[i], paddings[i],
73-
strides[i]));
71+
if ((!ctx->IsRuntime()) &&
72+
(in_dims[i + 2] == -1 || filter_dims[i + 2] == -1)) {
73+
output_shape.push_back(-1);
74+
} else {
75+
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
76+
dilations[i], paddings[i],
77+
strides[i]));
78+
}
7479
}
7580
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
7681
ctx->ShareLoD("Input", "Output");

paddle/fluid/operators/detection_map_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
5151
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
5252
"The rank of Input(Label) must be 2, "
5353
"the shape is [N, 6].");
54-
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
55-
"The shape of Input(Label) is [N, 6] or [N, 5].");
54+
if (ctx->IsRuntime() || label_dims[1] > 0) {
55+
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
56+
"The shape of Input(Label) is [N, 6] or [N, 5].");
57+
}
5658

5759
if (ctx->HasInput("PosCount")) {
5860
PADDLE_ENFORCE(ctx->HasInput("TruePos"),

paddle/fluid/operators/roi_pool_op.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ class ROIPoolOp : public framework::OperatorWithKernel {
5050
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
5151
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
5252

53-
PADDLE_ENFORCE_GT(pooled_height, 0,
54-
"The pooled output height must greater than 0");
55-
PADDLE_ENFORCE_GT(pooled_width, 0,
56-
"The pooled output width must greater than 0");
57-
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
58-
"The spatial scale must greater than 0");
53+
if (ctx->IsRuntime()) {
54+
PADDLE_ENFORCE_GT(pooled_height, 0,
55+
"The pooled output height must greater than 0");
56+
PADDLE_ENFORCE_GT(pooled_width, 0,
57+
"The pooled output width must greater than 0");
58+
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
59+
"The spatial scale must greater than 0");
60+
}
5961

6062
auto out_dims = input_dims;
6163
out_dims[0] = rois_dims[0];

paddle/fluid/operators/row_conv_op.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,18 @@ class RowConvOp : public framework::OperatorWithKernel {
4141
auto filter_dims = ctx->GetInputDim("Filter");
4242
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
4343
PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
44-
PADDLE_ENFORCE_EQ(
45-
x_dims[1], filter_dims[1],
46-
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
44+
if (ctx->IsRuntime()) {
45+
PADDLE_ENFORCE_EQ(
46+
x_dims[1], filter_dims[1],
47+
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
48+
} else {
49+
if (x_dims[1] > 0 && filter_dims[1] > 0) {
50+
PADDLE_ENFORCE_EQ(
51+
x_dims[1], filter_dims[1],
52+
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
53+
}
54+
}
55+
4756
ctx->SetOutputDim("Out", x_dims);
4857
ctx->ShareLoD("X", "Out");
4958
}

paddle/fluid/operators/unpool_op.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,15 @@ class UnpoolOp : public framework::OperatorWithKernel {
9999
PADDLE_ENFORCE(in_x_dims.size() == 4,
100100
"Unpooling intput must be of 4-dimensional.");
101101
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
102+
102103
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
103104
for (size_t i = 0; i < ksize.size(); ++i) {
104-
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
105-
paddings[i], strides[i]));
105+
if (!ctx->IsRuntime() && in_x_dims[i + 2] == -1) {
106+
output_shape.push_back(-1);
107+
} else {
108+
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
109+
paddings[i], strides[i]));
110+
}
106111
}
107112
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
108113
}

0 commit comments

Comments
 (0)