@@ -27,24 +27,35 @@ class UnStackOp : public framework::OperatorWithKernel {
27
27
using framework::OperatorWithKernel::OperatorWithKernel;
28
28
29
29
void InferShape (framework::InferShapeContext *ctx) const override {
30
- PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true , " Input(X) must exist." );
31
-
30
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " UnStack" );
32
31
int axis = ctx->Attrs ().Get <int >(" axis" );
33
32
int num = ctx->Attrs ().Get <int >(" num" );
34
33
auto x_dim = ctx->GetInputDim (" X" );
35
34
int rank = x_dim.size ();
36
- PADDLE_ENFORCE_GE (
37
- axis, -rank, " Attr(axis) must be inside [-rank, rank), where rank = %d" ,
38
- rank);
39
- PADDLE_ENFORCE_LT (
40
- axis, rank, " Attr(axis) must be inside [-rank, rank), where rank = %d" ,
41
- rank);
35
+ PADDLE_ENFORCE_GE (axis, -rank,
36
+ platform::errors::InvalidArgument (
37
+ " The attribute axis is out of range, it must be "
38
+ " inside [-rank, rank), where rank = %d" ,
39
+ rank));
40
+ PADDLE_ENFORCE_LT (axis, rank,
41
+ platform::errors::InvalidArgument (
42
+ " The attribute axis is out of range, it must be "
43
+ " inside [-rank, rank), where rank = %d" ,
44
+ rank));
42
45
if (axis < 0 ) axis += rank;
43
46
44
47
PADDLE_ENFORCE_EQ (ctx->Outputs (" Y" ).size (), static_cast <size_t >(num),
45
- " Number of Outputs(Y) is wrong" );
48
+ platform::errors::InvalidArgument (
49
+ " Number of Outputs(Y) is wrong. Got %d , but it must "
50
+ " equal to attribute num which is %d." ,
51
+ ctx->Outputs (" Y" ).size (), static_cast <size_t >(num)));
46
52
if (x_dim[axis] > 0 ) {
47
- PADDLE_ENFORCE_EQ (num, x_dim[axis], " Number of Outputs(Y) is wrong" );
53
+ PADDLE_ENFORCE_EQ (
54
+ num, x_dim[axis],
55
+ platform::errors::InvalidArgument (
56
+ " The number of attribute num is not equal to the length of the "
57
+ " %d axis of Input(X). Expect %d but got %d." ,
58
+ axis, x_dim[axis], num));
48
59
}
49
60
auto vec = framework::vectorize<int >(x_dim);
50
61
vec.erase (vec.begin () + axis);
@@ -89,24 +100,29 @@ class UnStackGradOp : public framework::OperatorWithKernel {
89
100
90
101
void InferShape (framework::InferShapeContext *ctx) const override {
91
102
PADDLE_ENFORCE_GT (ctx->Inputs (framework::GradVarName (" Y" )).size (), 0 ,
92
- " Number of Inputs(Y@Grad) must be larger than 0 " );
93
- PADDLE_ENFORCE_EQ (ctx-> HasOutput ( framework::GradVarName ( " X " )), true ,
94
- " Output(X@Grad) must exist. " );
95
-
103
+ platform::errors::InvalidArgument (
104
+ " Number of Inputs(Y@Grad) must be larger than 0 " ));
105
+ OP_INOUT_CHECK (ctx-> HasOutput ( framework::GradVarName ( " X " )), " Output " , " X " ,
106
+ " UnStackGrad " );
96
107
auto input_dims = ctx->GetInputsDim (framework::GradVarName (" Y" ));
97
108
for (size_t i = 1 ; i < input_dims.size (); ++i) {
98
109
PADDLE_ENFORCE_EQ (input_dims[i], input_dims[0 ],
99
- " Dims of all Inputs(Y@Grad) must be the same" );
110
+ platform::errors::InvalidArgument (
111
+ " Dims of all Inputs(Y@Grad) must be the same" ));
100
112
}
101
113
102
114
int axis = ctx->Attrs ().Get <int >(" axis" );
103
115
int rank = input_dims[0 ].size ();
104
- PADDLE_ENFORCE_GE (
105
- axis, -(rank + 1 ),
106
- " Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d" , rank);
107
- PADDLE_ENFORCE_LT (
108
- axis, rank + 1 ,
109
- " Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d" , rank);
116
+ PADDLE_ENFORCE_GE (axis, -(rank + 1 ),
117
+ platform::errors::InvalidArgument (
118
+ " The attribute axis is out of range, it must be "
119
+ " inside [-(rank+1), rank+1), where rank = %d" ,
120
+ rank));
121
+ PADDLE_ENFORCE_LT (axis, rank + 1 ,
122
+ platform::errors::InvalidArgument (
123
+ " The attribute axis is out of range, it must be "
124
+ " inside [-(rank+1), rank+1), where rank = %d" ,
125
+ rank));
110
126
if (axis < 0 ) axis += (rank + 1 );
111
127
112
128
auto vec = framework::vectorize<int >(input_dims[0 ]);
0 commit comments