File tree Expand file tree Collapse file tree 2 files changed +11
-6
lines changed Expand file tree Collapse file tree 2 files changed +11
-6
lines changed Original file line number Diff line number Diff line change @@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
47
47
auto weight_dims = ctx->GetInputDim (" Weight" );
48
48
int input_size = input_dims[1 ];
49
49
int frame_size = weight_dims[0 ];
50
- PADDLE_ENFORCE_EQ (input_size, frame_size * 3 ,
51
- " The input_size must be 3 times of frame_size in GRUOp." );
50
+ if (ctx->IsRuntime ()) {
51
+ PADDLE_ENFORCE_EQ (
52
+ input_size, frame_size * 3 ,
53
+ " The input_size must be 3 times of frame_size in GRUOp." );
54
+ }
52
55
PADDLE_ENFORCE_EQ (
53
56
weight_dims[1 ], frame_size * 3 ,
54
57
" The shape of Weight matrix must be [frame_size, frame_size * 3]." );
Original file line number Diff line number Diff line change @@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
34
34
auto c_prev_dims = ctx->GetInputDim (" C_prev" );
35
35
36
36
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank must be 2." );
37
- PADDLE_ENFORCE_EQ (x_dims[0 ], c_prev_dims[0 ],
38
- " Batch size of inputs and states must be equal" );
39
- PADDLE_ENFORCE_EQ (x_dims[1 ], c_prev_dims[1 ] * 4 ,
40
- " Dimension of FC should equal to prev state * 4" );
37
+ if (ctx->IsRuntime ()) {
38
+ PADDLE_ENFORCE_EQ (x_dims[0 ], c_prev_dims[0 ],
39
+ " Batch size of inputs and states must be equal" );
40
+ PADDLE_ENFORCE_EQ (x_dims[1 ], c_prev_dims[1 ] * 4 ,
41
+ " Dimension of FC should equal to prev state * 4" );
42
+ }
41
43
42
44
int b_size = c_prev_dims[0 ]; // batch size
43
45
int s_dim = c_prev_dims[1 ]; // state dim
You can’t perform that action at this time.
0 commit comments