@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel {
61
61
public:
62
62
using framework::OperatorWithKernel::OperatorWithKernel;
63
63
void InferShape (framework::InferShapeContext* ctx) const override {
64
- PADDLE_ENFORCE (ctx->HasInput (" X" ),
65
- " Input(X) of UnfoldOp should not be null" );
66
- PADDLE_ENFORCE (ctx->HasOutput (" Y" ),
67
- " Output(Y) of UnfoldOp should not be null" );
64
+ PADDLE_ENFORCE_EQ (
65
+ ctx->HasInput (" X" ), true ,
66
+ platform::errors::NotFound (" Input(X) of UnfoldOp should not be null" ));
67
+ PADDLE_ENFORCE_EQ (
68
+ ctx->HasOutput (" Y" ), true ,
69
+ platform::errors::NotFound (" Output(Y) of UnfoldOp should not be null" ));
68
70
auto in_dims = ctx->GetInputDim (" X" );
69
71
std::vector<int > kernel_sizes =
70
72
ctx->Attrs ().Get <std::vector<int >>(" kernel_sizes" );
@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel {
74
76
ctx->Attrs ().Get <std::vector<int >>(" dilations" );
75
77
76
78
// Only [N, C, H, W] input supported now
77
- PADDLE_ENFORCE (
78
- in_dims.size () == 4 ,
79
- " Input should be 4-D tensor of format [N, C, H, W], but get %u" ,
80
- in_dims.size ());
81
- PADDLE_ENFORCE (
82
- in_dims.size () - kernel_sizes.size () == 2U ,
83
- " The dims of X should be larger than that of kernel_sizes "
84
- " by a number of 2, due to the batch size and input channel dim. "
85
- " But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2" ,
86
- in_dims.size (), kernel_sizes.size ());
79
+ PADDLE_ENFORCE_EQ (
80
+ in_dims.size (), 4 ,
81
+ platform::errors::InvalidArgument (
82
+ " Input should be 4-D tensor of format [N, C, H, W], but get %u" ,
83
+ in_dims.size ()));
84
+ PADDLE_ENFORCE_EQ (
85
+ in_dims.size () - kernel_sizes.size (), 2U ,
86
+ platform::errors::InvalidArgument (
87
+ " The dims of X should be larger than that of kernel_sizes "
88
+ " by a number of 2, due to the batch size and input channel dim. "
89
+ " But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2" ,
90
+ in_dims.size (), kernel_sizes.size ()));
87
91
PADDLE_ENFORCE_EQ (
88
92
strides.size (), kernel_sizes.size (),
89
- " The dims of strides should be the same with that of kernel_sizes. "
90
- " But recieved dims(strides: %u) != dims(kernel_sizes: %u)." ,
91
- strides.size (), kernel_sizes.size ());
93
+ platform::errors::InvalidArgument (
94
+ " The dims of strides should be the same with that of kernel_sizes. "
95
+ " But recieved dims(strides: %u) != dims(kernel_sizes: %u)." ,
96
+ strides.size (), kernel_sizes.size ()));
92
97
PADDLE_ENFORCE_EQ (
93
98
paddings.size (), 2 * strides.size (),
94
- " The dims of paddings should be 2 times of that of strides. "
95
- " But recieved dims(paddings: %u) != 2*dims(strides: %u)." ,
96
- paddings.size (), strides.size ());
99
+ platform::errors::InvalidArgument (
100
+ " The dims of paddings should be 2 times of that of strides. "
101
+ " But recieved dims(paddings: %u) != 2*dims(strides: %u)." ,
102
+ paddings.size (), strides.size ()));
97
103
PADDLE_ENFORCE_EQ (
98
104
strides.size (), dilations.size (),
99
- " The dims of strides should be the same with that of dilations. "
100
- " But recieved dims(strides: %u) != dims(dilations: %u)." ,
101
- strides.size (), dilations.size ());
105
+ platform::errors::InvalidArgument (
106
+ " The dims of strides should be the same with that of dilations. "
107
+ " But recieved dims(strides: %u) != dims(dilations: %u)." ,
108
+ strides.size (), dilations.size ()));
102
109
103
110
std::vector<int > out_dims;
104
111
out_dims.push_back (in_dims[0 ]);
@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
131
138
using framework::OperatorWithKernel::OperatorWithKernel;
132
139
133
140
void InferShape (framework::InferShapeContext* ctx) const override {
134
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Y" )),
135
- " The gradient of Y should not be null" );
136
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " The input X should not be null" );
137
- PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" X" )),
138
- " The gradient of X should not be null" );
141
+ PADDLE_ENFORCE_EQ (
142
+ ctx->HasInput (framework::GradVarName (" Y" )), true ,
143
+ platform::errors::NotFound (" The gradient of Y should not be null" ));
144
+ PADDLE_ENFORCE_EQ (
145
+ ctx->HasInput (" X" ), true ,
146
+ platform::errors::NotFound (" The input X should not be null" ));
147
+ PADDLE_ENFORCE_EQ (
148
+ ctx->HasOutput (framework::GradVarName (" X" )), true ,
149
+ platform::errors::NotFound (" The gradient of X should not be null" ));
139
150
ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
140
151
}
141
152
0 commit comments