Skip to content

Commit 66655ad

Browse files
authored
test=release/1.8 cherry-pick unfold_op (#24505)
1 parent ec0f78a commit 66655ad

File tree

3 files changed

+50
-34
lines changed

3 files changed

+50
-34
lines changed

paddle/fluid/operators/unfold_op.cc

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel {
6161
public:
6262
using framework::OperatorWithKernel::OperatorWithKernel;
6363
void InferShape(framework::InferShapeContext* ctx) const override {
64-
PADDLE_ENFORCE(ctx->HasInput("X"),
65-
"Input(X) of UnfoldOp should not be null");
66-
PADDLE_ENFORCE(ctx->HasOutput("Y"),
67-
"Output(Y) of UnfoldOp should not be null");
64+
PADDLE_ENFORCE_EQ(
65+
ctx->HasInput("X"), true,
66+
platform::errors::NotFound("Input(X) of UnfoldOp should not be null"));
67+
PADDLE_ENFORCE_EQ(
68+
ctx->HasOutput("Y"), true,
69+
platform::errors::NotFound("Output(Y) of UnfoldOp should not be null"));
6870
auto in_dims = ctx->GetInputDim("X");
6971
std::vector<int> kernel_sizes =
7072
ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel {
7476
ctx->Attrs().Get<std::vector<int>>("dilations");
7577

7678
// Only [N, C, H, W] input supported now
77-
PADDLE_ENFORCE(
78-
in_dims.size() == 4,
79-
"Input should be 4-D tensor of format [N, C, H, W], but get %u",
80-
in_dims.size());
81-
PADDLE_ENFORCE(
82-
in_dims.size() - kernel_sizes.size() == 2U,
83-
"The dims of X should be larger than that of kernel_sizes "
84-
"by a number of 2, due to the batch size and input channel dim. "
85-
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
86-
in_dims.size(), kernel_sizes.size());
79+
PADDLE_ENFORCE_EQ(
80+
in_dims.size(), 4,
81+
platform::errors::InvalidArgument(
82+
"Input should be 4-D tensor of format [N, C, H, W], but get %u",
83+
in_dims.size()));
84+
PADDLE_ENFORCE_EQ(
85+
in_dims.size() - kernel_sizes.size(), 2U,
86+
platform::errors::InvalidArgument(
87+
"The dims of X should be larger than that of kernel_sizes "
88+
"by a number of 2, due to the batch size and input channel dim. "
89+
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
90+
in_dims.size(), kernel_sizes.size()));
8791
PADDLE_ENFORCE_EQ(
8892
strides.size(), kernel_sizes.size(),
89-
"The dims of strides should be the same with that of kernel_sizes. "
90-
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
91-
strides.size(), kernel_sizes.size());
93+
platform::errors::InvalidArgument(
94+
"The dims of strides should be the same with that of kernel_sizes. "
95+
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
96+
strides.size(), kernel_sizes.size()));
9297
PADDLE_ENFORCE_EQ(
9398
paddings.size(), 2 * strides.size(),
94-
"The dims of paddings should be 2 times of that of strides. "
95-
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
96-
paddings.size(), strides.size());
99+
platform::errors::InvalidArgument(
100+
"The dims of paddings should be 2 times of that of strides. "
101+
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
102+
paddings.size(), strides.size()));
97103
PADDLE_ENFORCE_EQ(
98104
strides.size(), dilations.size(),
99-
"The dims of strides should be the same with that of dilations. "
100-
"But recieved dims(strides: %u) != dims(dilations: %u).",
101-
strides.size(), dilations.size());
105+
platform::errors::InvalidArgument(
106+
"The dims of strides should be the same with that of dilations. "
107+
"But recieved dims(strides: %u) != dims(dilations: %u).",
108+
strides.size(), dilations.size()));
102109

103110
std::vector<int> out_dims;
104111
out_dims.push_back(in_dims[0]);
@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
131138
using framework::OperatorWithKernel::OperatorWithKernel;
132139

133140
void InferShape(framework::InferShapeContext* ctx) const override {
134-
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
135-
"The gradient of Y should not be null");
136-
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null");
137-
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
138-
"The gradient of X should not be null");
141+
PADDLE_ENFORCE_EQ(
142+
ctx->HasInput(framework::GradVarName("Y")), true,
143+
platform::errors::NotFound("The gradient of Y should not be null"));
144+
PADDLE_ENFORCE_EQ(
145+
ctx->HasInput("X"), true,
146+
platform::errors::NotFound("The input X should not be null"));
147+
PADDLE_ENFORCE_EQ(
148+
ctx->HasOutput(framework::GradVarName("X")), true,
149+
platform::errors::NotFound("The gradient of X should not be null"));
139150
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
140151
}
141152

paddle/fluid/operators/unfold_op.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation,
2929
int padding1, int padding2, int stride) {
3030
const int dkernel = dilation * (filter_size - 1) + 1;
3131
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
32-
PADDLE_ENFORCE(output_size > 0,
33-
"Due to the settings of padding(%d, %d), filter_size(%d), "
34-
"dilation(%d) and "
35-
"stride(%d), the output size is less than 0, please check "
36-
"again. Input_size:%d",
37-
padding1, padding2, filter_size, dilation, stride, input_size);
32+
33+
PADDLE_ENFORCE_GT(
34+
output_size, 0UL,
35+
platform::errors::InvalidArgument(
36+
"Due to the settings of padding(%d, %d), filter_size(%d), "
37+
"dilation(%d) and "
38+
"stride(%d), the output size is less than 0, please check "
39+
"again. Input_size:%d",
40+
padding1, padding2, filter_size, dilation, stride, input_size));
3841

3942
return output_size;
4043
}

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15257,6 +15257,8 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
1525715257

1525815258
helper = LayerHelper("unfold", **locals())
1525915259

15260+
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'unfold')
15261+
1526015262
assert len(x.shape) == 4, \
1526115263
"input should be the format of [N, C, H, W]"
1526215264

0 commit comments

Comments
 (0)