Skip to content

Commit b637460

Browse files
committed
polish the code
test=develop
1 parent cb6ece6 commit b637460

File tree

5 files changed

+12
-23
lines changed

5 files changed

+12
-23
lines changed

paddle/fluid/operators/affine_channel_op.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,11 @@ class AffineChannelOp : public framework::OperatorWithKernel {
8080

8181
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
8282
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
83-
if (ctx->IsRuntime()) {
83+
if (ctx->IsRuntime() || scale_dims[0] > 0) {
8484
PADDLE_ENFORCE_EQ(scale_dims[0], C);
85+
}
86+
if (ctx->IsRuntime() || b_dims[0] > 0) {
8587
PADDLE_ENFORCE_EQ(b_dims[0], C);
86-
} else {
87-
if (scale_dims[0] > 0 && b_dims[0] > 0) {
88-
PADDLE_ENFORCE_EQ(scale_dims[0], C);
89-
PADDLE_ENFORCE_EQ(b_dims[0], C);
90-
}
9188
}
9289

9390
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));

paddle/fluid/operators/conv_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
6969
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
7070
for (size_t i = 0; i < strides.size(); ++i) {
7171
if ((!ctx->IsRuntime()) &&
72-
(in_dims[i + 2] == -1 || filter_dims[i + 2] == -1)) {
72+
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
7373
output_shape.push_back(-1);
7474
} else {
7575
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],

paddle/fluid/operators/roi_pool_op.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,12 @@ class ROIPoolOp : public framework::OperatorWithKernel {
5050
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
5151
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
5252

53-
if (ctx->IsRuntime()) {
54-
PADDLE_ENFORCE_GT(pooled_height, 0,
55-
"The pooled output height must greater than 0");
56-
PADDLE_ENFORCE_GT(pooled_width, 0,
57-
"The pooled output width must greater than 0");
58-
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
59-
"The spatial scale must greater than 0");
60-
}
53+
PADDLE_ENFORCE_GT(pooled_height, 0,
54+
"The pooled output height must greater than 0");
55+
PADDLE_ENFORCE_GT(pooled_width, 0,
56+
"The pooled output width must greater than 0");
57+
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
58+
"The spatial scale must greater than 0");
6159

6260
auto out_dims = input_dims;
6361
out_dims[0] = rois_dims[0];

paddle/fluid/operators/row_conv_op.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,10 @@ class RowConvOp : public framework::OperatorWithKernel {
4141
auto filter_dims = ctx->GetInputDim("Filter");
4242
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
4343
PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
44-
if (ctx->IsRuntime()) {
44+
if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) {
4545
PADDLE_ENFORCE_EQ(
4646
x_dims[1], filter_dims[1],
4747
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
48-
} else {
49-
if (x_dims[1] > 0 && filter_dims[1] > 0) {
50-
PADDLE_ENFORCE_EQ(
51-
x_dims[1], filter_dims[1],
52-
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
53-
}
5448
}
5549

5650
ctx->SetOutputDim("Out", x_dims);

paddle/fluid/operators/unpool_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
102102

103103
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
104104
for (size_t i = 0; i < ksize.size(); ++i) {
105-
if (!ctx->IsRuntime() && in_x_dims[i + 2] == -1) {
105+
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
106106
output_shape.push_back(-1);
107107
} else {
108108
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],

0 commit comments

Comments
 (0)