@@ -31,31 +31,33 @@ class TransposeOp : public framework::OperatorWithKernel {
31
31
using framework::OperatorWithKernel::OperatorWithKernel;
32
32
33
33
void InferShape (framework::InferShapeContext *ctx) const override {
34
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) should not be null " );
35
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ), " Output( Out) should not be null " );
34
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X " , " Transpose " );
35
+ OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " Transpose " );
36
36
auto x_dims = ctx->GetInputDim (" X" );
37
37
std::vector<int > axis = ctx->Attrs ().Get <std::vector<int >>(" axis" );
38
38
size_t x_rank = x_dims.size ();
39
39
size_t axis_size = axis.size ();
40
40
41
41
PADDLE_ENFORCE_EQ (x_rank, axis_size,
42
- " ShapeError: The input tensor's dimension "
43
- " should be equal to the axis's size. "
44
- " But received input tensor's dimension is %d, "
45
- " axis's size is %d" ,
46
- x_rank, axis_size);
42
+ platform::errors::InvalidArgument (
43
+ " The input tensor's dimension "
44
+ " should be equal to the axis's size. "
45
+ " But received input tensor's dimension is %d, "
46
+ " axis's size is %d" ,
47
+ x_rank, axis_size));
47
48
48
49
std::vector<int > count (axis_size, 0 );
49
50
for (size_t i = 0 ; i < axis_size; i++) {
50
- PADDLE_ENFORCE (
51
- axis[i] < static_cast <int >(axis_size) && ++count[axis[i]] == 1 ,
52
- " ValueError: Each element of Attribute axis should "
53
- " be a unique value range from 0 to (dims - 1), "
54
- " where the dims is the axis's size, "
55
- " unique value means this axis value can appear only once. "
56
- " But received axis[%d] is %d, axis_size is %d, "
57
- " count[axis[%d]] is %d" ,
58
- i, axis[i], axis_size, i, count[axis[i]]);
51
+ PADDLE_ENFORCE_EQ (
52
+ axis[i] < static_cast <int >(axis_size) && ++count[axis[i]] == 1 , true ,
53
+ platform::errors::InvalidArgument (
54
+ " Each element of Attribute axis should "
55
+ " be a unique value range from 0 to (dims - 1), "
56
+ " where the dims is the axis's size, "
57
+ " unique value means this axis value can appear only once. "
58
+ " But received axis[%d] is %d, axis_size is %d, "
59
+ " count[axis[%d]] is %d" ,
60
+ i, axis[i], axis_size, i, count[axis[i]]));
59
61
}
60
62
61
63
framework::DDim out_dims (x_dims);
@@ -149,9 +151,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
149
151
using framework::OperatorWithKernel::OperatorWithKernel;
150
152
151
153
void InferShape (framework::InferShapeContext *ctx) const override {
152
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) should not be null " );
153
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
154
- " Input( Out@GRAD) should not be null " );
154
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X " , " TransposeOpGrad " );
155
+ OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Out" )), " Input " ,
156
+ framework::GradVarName ( " Out" ), " TransposeOpGrad " );
155
157
auto x_dims = ctx->GetInputDim (" X" );
156
158
ctx->SetOutputDim (framework::GradVarName (" X" ), x_dims);
157
159
if (ctx->HasOutput (framework::GradVarName (" X" ))) {
@@ -193,8 +195,7 @@ class Transpose2Op : public TransposeOp {
193
195
194
196
void InferShape (framework::InferShapeContext *ctx) const override {
195
197
TransposeOp::InferShape (ctx);
196
- PADDLE_ENFORCE (ctx->HasOutput (" XShape" ),
197
- " Output(XShape) should not be null" );
198
+ OP_INOUT_CHECK (ctx->HasOutput (" XShape" ), " Output" , " XShape" , " Transpose2" );
198
199
const auto &in_dims = ctx->GetInputDim (" X" );
199
200
std::vector<int64_t > x_shape_dim (in_dims.size () + 1 );
200
201
x_shape_dim[0 ] = 0 ;
@@ -259,9 +260,10 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
259
260
using framework::OperatorWithKernel::OperatorWithKernel;
260
261
261
262
void InferShape (framework::InferShapeContext *ctx) const override {
262
- PADDLE_ENFORCE (ctx->HasInput (" XShape" ), " Input(XShape) should not be null" );
263
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
264
- " Input(Out@GRAD) should not be null" );
263
+ OP_INOUT_CHECK (ctx->HasInput (" XShape" ), " Input" , " XShape" ,
264
+ " Transpose2OpGrad" );
265
+ OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Out" )), " Input" ,
266
+ framework::GradVarName (" Out" ), " Transpose2OpGrad" );
265
267
if (ctx->HasOutput (framework::GradVarName (" X" ))) {
266
268
auto xshape_dim = ctx->GetInputDim (" XShape" );
267
269
auto x_shape_dim =
0 commit comments