@@ -31,44 +31,58 @@ class GRUOp : public framework::OperatorWithKernel {
31
31
using framework::OperatorWithKernel::OperatorWithKernel;
32
32
33
33
void InferShape (framework::InferShapeContext* ctx) const override {
34
- PADDLE_ENFORCE (ctx->HasInput (" Input" ),
35
- " Input(%s) of GRUOp should not be null." , " Input" );
36
- PADDLE_ENFORCE (ctx->HasInput (" Weight" ),
37
- " Input(%s) of GRUOp should not be null." , " Weight" );
38
- PADDLE_ENFORCE (ctx->HasOutput (" BatchGate" ),
39
- " Output(%s) of GRUOp should not be null." , " BatchGate" );
40
- PADDLE_ENFORCE (ctx->HasOutput (" BatchResetHiddenPrev" ),
41
- " Output(%s) of GRUOp should not be null." ,
42
- " BatchResetHiddenPrev" );
43
- PADDLE_ENFORCE (ctx->HasOutput (" BatchHidden" ),
44
- " Output(%s) of GRUOp should not be null." , " BatchHidden" );
45
- PADDLE_ENFORCE (ctx->HasOutput (" Hidden" ),
46
- " Output(%s) of GRUOp should not be null." , " Hidden" );
34
+ OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " GRU" );
35
+ OP_INOUT_CHECK (ctx->HasInput (" Weight" ), " Input" , " Weight" , " GRU" );
36
+ OP_INOUT_CHECK (ctx->HasOutput (" BatchGate" ), " Output" , " BatchGate" , " GRU" );
37
+ OP_INOUT_CHECK (ctx->HasOutput (" BatchResetHiddenPrev" ), " Output" ,
38
+ " BatchResetHiddenPrev" , " GRU" );
39
+ OP_INOUT_CHECK (ctx->HasOutput (" BatchHidden" ), " Output" , " BatchHidden" ,
40
+ " GRU" );
41
+ OP_INOUT_CHECK (ctx->HasOutput (" Hidden" ), " Output" , " Hidden" , " GRU" );
42
+
47
43
auto input_dims = ctx->GetInputDim (" Input" );
48
44
auto weight_dims = ctx->GetInputDim (" Weight" );
49
45
int input_size = input_dims[1 ];
50
46
int frame_size = weight_dims[0 ];
51
47
if (ctx->IsRuntime ()) {
52
- PADDLE_ENFORCE_EQ (
53
- input_size, frame_size * 3 ,
54
- " The input_size must be 3 times of frame_size in GRUOp." );
48
+ PADDLE_ENFORCE_EQ (input_size, frame_size * 3 ,
49
+ platform::errors::InvalidArgument (
50
+ " The second dimension of Input(Input) must be 3 "
51
+ " times of frame_size in GRUOp, but received %d "
52
+ " (Input) vs %d (frame_size)." ,
53
+ input_size, frame_size));
55
54
}
56
55
PADDLE_ENFORCE_EQ (
57
56
weight_dims[1 ], frame_size * 3 ,
58
- " The shape of Weight matrix must be [frame_size, frame_size * 3]." );
57
+ platform::errors::InvalidArgument (
58
+ " The shape of Input(Weight) matrix must be [frame_size, frame_size "
59
+ " * 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size)." ,
60
+ weight_dims[0 ], weight_dims[1 ], frame_size, frame_size * 3 ));
59
61
if (ctx->HasInput (" H0" )) {
60
62
auto h0_dims = ctx->GetInputDim (" H0" );
61
- PADDLE_ENFORCE_EQ (h0_dims[1 ], frame_size,
62
- " The width of H0 must be equal to frame_size." );
63
+ PADDLE_ENFORCE_EQ (
64
+ h0_dims[1 ], frame_size,
65
+ platform::errors::InvalidArgument (
66
+ " The width of Input(H0) must be equal to frame_size, but "
67
+ " received %d (width of H0) vs %d (frame_size)." ,
68
+ h0_dims[1 ], frame_size));
63
69
}
64
70
if (ctx->HasInput (" Bias" )) {
65
71
auto bias_dims = ctx->GetInputDim (" Bias" );
66
72
int bias_height = bias_dims[0 ];
67
73
int bias_width = bias_dims[1 ];
68
- PADDLE_ENFORCE_EQ (bias_height, 1 ,
69
- " The shape of Bias must be [1, frame_size * 3]." );
70
- PADDLE_ENFORCE_EQ (bias_width, frame_size * 3 ,
71
- " The shape of Bias must be [1, frame_size * 3]." );
74
+ PADDLE_ENFORCE_EQ (
75
+ bias_height, 1 ,
76
+ platform::errors::InvalidArgument (
77
+ " The shape of Bias must be [1, frame_size * 3], but received "
78
+ " [%d, %d] (Bias) vs [1, %d] (frame_size * 3)." ,
79
+ bias_height, bias_width, frame_size * 3 ));
80
+ PADDLE_ENFORCE_EQ (
81
+ bias_width, frame_size * 3 ,
82
+ platform::errors::InvalidArgument (
83
+ " The shape of Bias must be [1, frame_size * 3], but received "
84
+ " [%d, %d] (Bias) vs [1, %d] (frame_size * 3)." ,
85
+ bias_height, bias_width, frame_size * 3 ));
72
86
}
73
87
ctx->SetOutputDim (" BatchGate" , input_dims);
74
88
ctx->SetOutputDim (" BatchResetHiddenPrev" , {input_dims[0 ], frame_size});
@@ -166,39 +180,50 @@ class GRUGradOp : public framework::OperatorWithKernel {
166
180
using framework::OperatorWithKernel::OperatorWithKernel;
167
181
168
182
void InferShape (framework::InferShapeContext* ctx) const override {
169
- PADDLE_ENFORCE (ctx->HasInput (" Input" ),
170
- " Input(%s) of GRUGradOp should not be null." , " Input" );
171
- PADDLE_ENFORCE (ctx->HasInput (" Weight" ),
172
- " Input(%s) of GRUGradOp should not be null." , " Weight" );
173
- PADDLE_ENFORCE (ctx->HasInput (" BatchGate" ),
174
- " Input(%s) of GRUGradOp should not be null." , " BatchGate" );
175
- PADDLE_ENFORCE (ctx->HasInput (" BatchResetHiddenPrev" ),
176
- " Input(%s) of GRUGradOp should not be null." ,
177
- " BatchResetHiddenPrev" );
178
- PADDLE_ENFORCE (ctx->HasInput (" BatchHidden" ),
179
- " Input(%s) of GRUOp should not be null." , " BatchHidden" );
180
- PADDLE_ENFORCE (ctx->HasInput (" Hidden" ),
181
- " Input(%s) of GRUGradOp should not be null." , " Hidden" );
182
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Hidden" )),
183
- " Input(%s@GRAD) of GRUGradOp should not be null." , " Hidden" );
183
+ OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " GRU@Grad" );
184
+ OP_INOUT_CHECK (ctx->HasInput (" Weight" ), " Input" , " Weight" , " GRU@Grad" );
185
+ OP_INOUT_CHECK (ctx->HasInput (" BatchGate" ), " Input" , " BatchGate" ,
186
+ " GRU@Grad" );
187
+ OP_INOUT_CHECK (ctx->HasInput (" BatchResetHiddenPrev" ), " Input" ,
188
+ " BatchResetHiddenPrev" , " GRU@Grad" );
189
+ OP_INOUT_CHECK (ctx->HasInput (" BatchHidden" ), " Input" , " BatchHidden" ,
190
+ " GRU@Grad" );
191
+ OP_INOUT_CHECK (ctx->HasInput (" Hidden" ), " Input" , " Hidden" , " GRU@Grad" );
192
+ OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Hidden" )), " Input" ,
193
+ framework::GradVarName (" Hidden" ), " GRU@Grad" );
194
+
184
195
auto input_dims = ctx->GetInputDim (" Input" );
185
196
auto weight_dims = ctx->GetInputDim (" Weight" );
186
197
int input_size = input_dims[1 ];
187
198
int frame_size = weight_dims[0 ];
188
199
int weight_height = weight_dims[0 ];
189
200
int weight_width = weight_dims[1 ];
190
- PADDLE_ENFORCE_EQ (input_size, frame_size * 3 ,
191
- " The input_size must be 3 times of frame_size in GRUOp." );
201
+ PADDLE_ENFORCE_EQ (
202
+ input_size, frame_size * 3 ,
203
+ platform::errors::InvalidArgument (
204
+ " The second dimension of Input(Input) must be 3 times of "
205
+ " frame_size in GRUOp, but received %d (Input) vs %d (frame_size)." ,
206
+ input_size, frame_size));
192
207
PADDLE_ENFORCE_EQ (
193
208
weight_height, frame_size,
194
- " The shape of Weight matrix must be [frame_size, frame_size * 3]." );
209
+ platform::errors::InvalidArgument (
210
+ " The shape of Input(Weight) matrix must be [frame_size, frame_size "
211
+ " * 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size)." ,
212
+ weight_height, weight_width, frame_size, frame_size * 3 ));
195
213
PADDLE_ENFORCE_EQ (
196
214
weight_width, frame_size * 3 ,
197
- " The shape of Weight matrix must be [frame_size, frame_size * 3]." );
215
+ platform::errors::InvalidArgument (
216
+ " The shape of Input(Weight) matrix must be [frame_size, frame_size "
217
+ " * 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size)." ,
218
+ weight_height, weight_width, frame_size, frame_size * 3 ));
198
219
if (ctx->HasInput (" H0" )) {
199
220
auto h0_dims = ctx->GetInputDim (" H0" );
200
- PADDLE_ENFORCE_EQ (h0_dims[1 ], frame_size,
201
- " The width of H0 must be equal to frame_size." );
221
+ PADDLE_ENFORCE_EQ (
222
+ h0_dims[1 ], frame_size,
223
+ platform::errors::InvalidArgument (
224
+ " The width of Input(H0) must be equal to frame_size, but "
225
+ " received %d (width of H0) vs %d (frame_size)." ,
226
+ h0_dims[1 ], frame_size));
202
227
auto h0_grad_name = framework::GradVarName (" H0" );
203
228
if (ctx->HasOutput (h0_grad_name))
204
229
ctx->SetOutputDim (h0_grad_name, h0_dims);
@@ -207,10 +232,18 @@ class GRUGradOp : public framework::OperatorWithKernel {
207
232
auto bias_dims = ctx->GetInputDim (" Bias" );
208
233
int bias_height = bias_dims[0 ];
209
234
int bias_width = bias_dims[1 ];
210
- PADDLE_ENFORCE_EQ (bias_height, 1 ,
211
- " The shape of Bias must be [1, frame_size * 3]." );
212
- PADDLE_ENFORCE_EQ (bias_width, frame_size * 3 ,
213
- " The shape of Bias must be [1, frame_size * 3]." );
235
+ PADDLE_ENFORCE_EQ (
236
+ bias_height, 1 ,
237
+ platform::errors::InvalidArgument (
238
+ " The shape of Bias must be [1, frame_size * 3], but received "
239
+ " [%d, %d] (Bias) vs [1, %d] (frame_size * 3)." ,
240
+ bias_height, bias_width, frame_size * 3 ));
241
+ PADDLE_ENFORCE_EQ (
242
+ bias_width, frame_size * 3 ,
243
+ platform::errors::InvalidArgument (
244
+ " The shape of Bias must be [1, frame_size * 3], but received "
245
+ " [%d, %d] (Bias) vs [1, %d] (frame_size * 3)." ,
246
+ bias_height, bias_width, frame_size * 3 ));
214
247
auto bias_grad_name = framework::GradVarName (" Bias" );
215
248
if (ctx->HasOutput (bias_grad_name))
216
249
ctx->SetOutputDim (bias_grad_name, bias_dims);
@@ -298,14 +331,20 @@ class GRUCPUKernel : public framework::OpKernel<T> {
298
331
T* packed_gate = blas.GEMM_ALLOC (CblasBMatrix, 1 /* height of C*/ ,
299
332
frame_size * 2 /* width of weight*/ ,
300
333
frame_size /* height of height*/ );
301
- PADDLE_ENFORCE (packed_gate);
334
+ PADDLE_ENFORCE_NOT_NULL (
335
+ packed_gate, platform::errors::NotFound (
336
+ " The caculation result of packed_gate by "
337
+ " GEMM_ALLOC should not be null when using MKL." ));
302
338
blas.GEMM_PACK (CblasBMatrix, CblasNoTrans, 1 /* cur bs?*/ , frame_size * 2 ,
303
339
frame_size, T (1.0 ), gru_value.gate_weight , frame_size * 2 ,
304
340
packed_gate);
305
341
T* packed_state = blas.GEMM_ALLOC (CblasBMatrix, 1 /* height of C*/ ,
306
342
frame_size /* width of weight*/ ,
307
343
frame_size /* height of height*/ );
308
- PADDLE_ENFORCE (packed_state);
344
+ PADDLE_ENFORCE_NOT_NULL (
345
+ packed_state, platform::errors::NotFound (
346
+ " The caculation result of packed_state by "
347
+ " GEMM_ALLOC should not be null when using MKL." ));
309
348
blas.GEMM_PACK (CblasBMatrix, CblasNoTrans, 1 /* cur bs?*/ , frame_size,
310
349
frame_size, T (1.0 ), gru_value.state_weight , frame_size,
311
350
packed_state);
0 commit comments