@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasOutput("Projection"),
-                   "Output(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchCellPreAct) of LSTMP operator should not be "
-                   "null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                   "Output(BatchHidden) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
+                   "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                   "BatchCellPreAct", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
+                   "LSTMP");
 
     auto in_dims = ctx->GetInputDim("Input");
 
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2,
-                      "Input(X)'s rank of LSTMP operator must be 2.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Input(X)'s rank of LSTMP operator must be 2, but received %d.",
+            in_dims.size()));
 
     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
     auto proj_dims = ctx->GetInputDim("ProjWeight");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                      "The rank of Input(Weight) should be 2.");
-    PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
-                      "The first dimension of Input(Weight) "
-                      "should be %d.",
-                      proj_dims[1]);
+    PADDLE_ENFORCE_EQ(
+        w_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Weight) should be 2, but received %d.",
+            w_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        w_dims[0], proj_dims[1],
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Weight) and the second dimension of "
+            "Input(ProjWeight) should be the same, but received %d vs %d.",
+            w_dims[0], proj_dims[1]));
 
     PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Weight) "
-                      "should be 4 * %d.",
-                      frame_size);
-
-    PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
-                      "The rank of Input(ProjWeight) should be 2.");
+                      platform::errors::InvalidArgument(
+                          "The second dimension of Input(Weight) should be 4 * "
+                          "%d, but received %d.",
+                          frame_size, w_dims[1]));
+
+    PADDLE_ENFORCE_EQ(
+        proj_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(ProjWeight) should be 2, but received %d.",
+            proj_dims.size()));
     PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
-                      "The first dimension of Input(ProjWeight) "
-                      "should be %d.",
-                      frame_size);
+                      platform::errors::InvalidArgument(
+                          "The first dimension of Input(ProjWeight) should be "
+                          "%d, but received %d.",
+                          frame_size, proj_dims[0]));
 
     if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(C0) of LSTMP operator should not be null after "
-                     "Input(H0) provided.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("C0"), true,
+          platform::errors::NotFound("Input(C0) of LSTMP operator should not "
+                                     "be null after Input(H0) provided."));
     }
 
     auto b_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(
+        b_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Bias) should be 2, but received %d.",
+            b_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        b_dims[0], 1,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Bias) should be 1, but received %d.",
+            b_dims[0]));
 
     if (ctx->Attrs().Get<bool>("use_peepholes")) {
-      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "7 * %d if enable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 7 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 7 * %d if enable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "4 * %d if disable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 4 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 4 * %d if disable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     }
 
     framework::DDim out_dims({in_dims[0], frame_size});
@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Projection"),
-                   "Input(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
+
+    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
+                   "LSTMP@Grad");
 
     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
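The change above swaps bare PADDLE_ENFORCE presence checks for OP_INOUT_CHECK(condition, "Input"/"Output", variable_name, op_type) and upgrades the shape checks to PADDLE_ENFORCE_EQ with a typed platform::errors::* message that reports the value actually received. A minimal sketch of the same pattern in another operator's InferShape follows; the FooOp class and the "X"/"Out" variable names are hypothetical, the framework headers are assumed to be pulled in as in lstmp_op.cc, and the macro usage is taken directly from this diff.

// Sketch only, not part of this commit. Assumes the usual fluid operator
// headers (op_registry.h etc.) are included, as lstmp_op.cc already does.
class FooOp : public framework::OperatorWithKernel {  // hypothetical operator
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Presence checks: OP_INOUT_CHECK(condition, "Input"/"Output", var, op).
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Foo");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Foo");

    // Shape checks: PADDLE_ENFORCE_EQ plus a typed error message that states
    // both the expected value and the value actually received.
    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(
        x_dims.size(), 2,
        platform::errors::InvalidArgument(
            "Input(X)'s rank of Foo operator must be 2, but received %d.",
            x_dims.size()));

    ctx->SetOutputDim("Out", x_dims);
  }
};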