@@ -20,17 +20,27 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
20
20
using framework::OperatorWithKernel::OperatorWithKernel;
21
21
22
22
void InferShape (framework::InferShapeContext* ctx) const override {
23
- PADDLE_ENFORCE (ctx->HasInput (" X" ),
24
- " Input(X) of PixelShuffleOp should not be null." );
25
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
26
- " Output(Out) of PixelShuffleOp should not be null." );
23
+ PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true ,
24
+ platform::errors::NotFound (
25
+ " Input(X) of PixelShuffleOp should not be null." ));
26
+ PADDLE_ENFORCE_EQ (ctx->HasOutput (" Out" ), true ,
27
+ platform::errors::NotFound (
28
+ " Output(Out) of PixelShuffleOp should not be null." ));
27
29
28
30
auto input_dims = ctx->GetInputDim (" X" );
29
- PADDLE_ENFORCE (input_dims.size () == 4 , " The layout of input is NCHW." );
31
+ PADDLE_ENFORCE_EQ (
32
+ input_dims.size (), 4 ,
33
+ platform::errors::InvalidArgument (
34
+ " Input should be a 4-D tensor of format [N, C, H, W], but got %u." ,
35
+ input_dims.size ()));
36
+
30
37
auto upscale_factor = ctx->Attrs ().Get <int >(" upscale_factor" );
31
38
32
- PADDLE_ENFORCE (input_dims[1 ] % (upscale_factor * upscale_factor) == 0 ,
33
- " Upscale_factor should devide the number of channel" );
39
+ PADDLE_ENFORCE_EQ (input_dims[1 ] % (upscale_factor * upscale_factor), 0 ,
40
+ platform::errors::InvalidArgument (
41
+ " The square of upscale_factor[%u] should divide the "
42
+ " number of channel[%u]" ,
43
+ input_dims[1 ], upscale_factor * upscale_factor));
34
44
35
45
auto output_dims = input_dims;
36
46
output_dims[0 ] = input_dims[0 ];
@@ -57,7 +67,8 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
57
67
.SetDefault (1 )
58
68
.AddCustomChecker ([](const int & upscale_factor) {
59
69
PADDLE_ENFORCE_GE (upscale_factor, 1 ,
60
- " upscale_factor should be larger than 0." );
70
+ platform::errors::InvalidArgument (
71
+ " upscale_factor should be larger than 0." ));
61
72
});
62
73
63
74
AddComment (R"DOC(
@@ -95,13 +106,19 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
95
106
using framework::OperatorWithKernel::OperatorWithKernel;
96
107
97
108
void InferShape (framework::InferShapeContext* ctx) const override {
98
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
99
- " Input(Out@Grad) should not be null" );
100
- PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" X" )),
101
- " Output(X@Grad) should not be null" );
109
+ PADDLE_ENFORCE_EQ (
110
+ ctx->HasInput (framework::GradVarName (" Out" )), true ,
111
+ platform::errors::NotFound (" Input(Out@Grad) should not be null" ));
112
+ PADDLE_ENFORCE_EQ (
113
+ ctx->HasOutput (framework::GradVarName (" X" )), true ,
114
+ platform::errors::NotFound (" Output(X@Grad) should not be null" ));
102
115
103
116
auto do_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
104
- PADDLE_ENFORCE (do_dims.size () == 4 , " The layout of input is NCHW." );
117
+ PADDLE_ENFORCE_EQ (
118
+ do_dims.size (), 4 ,
119
+ platform::errors::InvalidArgument (
120
+ " Input should be a 4-D tensor of format [N, C, H, W], but got %u." ,
121
+ do_dims.size ()));
105
122
106
123
auto upscale_factor = ctx->Attrs ().Get <int >(" upscale_factor" );
107
124
0 commit comments