@@ -71,10 +71,9 @@ class MaxOutOp : public framework::OperatorWithKernel {
71
71
using framework::OperatorWithKernel::OperatorWithKernel;
72
72
void InferShape (framework::InferShapeContext* ctx) const override {
73
73
PADDLE_ENFORCE (ctx->HasInput (" X" ),
74
- " Input tensor of MaxoutOp"
75
- " should not be null." );
74
+ " Input(X) of MaxoutOpshould not be null." );
76
75
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
77
- " Output tensor of MaxoutOp should not be null." );
76
+ " Output(Out) of MaxoutOp should not be null." );
78
77
auto in_x_dims = ctx->GetInputDim (" X" );
79
78
int groups = ctx->Attrs ().Get <int >(" groups" );
80
79
// check groups > 1
@@ -90,9 +89,10 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
90
89
public:
91
90
using framework::OperatorWithKernel::OperatorWithKernel;
92
91
void InferShape (framework::InferShapeContext* ctx) const override {
93
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input tensor must not be null." );
92
+ PADDLE_ENFORCE (ctx->HasInput (" X" ),
93
+ " Input(X) of MaxOutOpGrad must not be null." );
94
94
PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" X" )),
95
- " Output gradient tensor should not be null." );
95
+ " Output(Grad@X) of MaxOutOpGrad should not be null." );
96
96
ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
97
97
}
98
98
};
0 commit comments