File tree Expand file tree Collapse file tree 3 files changed +21
-8
lines changed Expand file tree Collapse file tree 3 files changed +21
-8
lines changed Original file line number Diff line number Diff line change @@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
71
71
" Input(Y) of %s operator must not be null" , comment.type );
72
72
auto dim_x = context->GetInputDim (" X" );
73
73
auto dim_y = context->GetInputDim (" Y" );
74
- PADDLE_ENFORCE_EQ (framework::product (dim_x), framework::product (dim_y),
75
- " The number of elements in X and Y should be same" );
74
+
75
+ int product_x = framework::product (dim_x);
76
+ int product_y = framework::product (dim_y);
77
+ bool check = context->IsRuntime () || (product_x >= 0 && product_y >= 0 );
78
+ if (check) {
79
+ PADDLE_ENFORCE_EQ (
80
+ product_x, product_y,
81
+ " The number of elements in X and Y should be same, %d != %d" ,
82
+ product_x, product_y);
83
+ }
76
84
77
85
context->SetOutputDim (" Out" , context->GetInputDim (" X" ));
78
86
context->ShareLoD (" X" , " Out" );
Original file line number Diff line number Diff line change @@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
47
47
auto weight_dims = ctx->GetInputDim (" Weight" );
48
48
int input_size = input_dims[1 ];
49
49
int frame_size = weight_dims[0 ];
50
- PADDLE_ENFORCE_EQ (input_size, frame_size * 3 ,
51
- " The input_size must be 3 times of frame_size in GRUOp." );
50
+ if (ctx->IsRuntime ()) {
51
+ PADDLE_ENFORCE_EQ (
52
+ input_size, frame_size * 3 ,
53
+ " The input_size must be 3 times of frame_size in GRUOp." );
54
+ }
52
55
PADDLE_ENFORCE_EQ (
53
56
weight_dims[1 ], frame_size * 3 ,
54
57
" The shape of Weight matrix must be [frame_size, frame_size * 3]." );
Original file line number Diff line number Diff line change @@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
34
34
auto c_prev_dims = ctx->GetInputDim (" C_prev" );
35
35
36
36
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank must be 2." );
37
- PADDLE_ENFORCE_EQ (x_dims[0 ], c_prev_dims[0 ],
38
- " Batch size of inputs and states must be equal" );
39
- PADDLE_ENFORCE_EQ (x_dims[1 ], c_prev_dims[1 ] * 4 ,
40
- " Dimension of FC should equal to prev state * 4" );
37
+ if (ctx->IsRuntime ()) {
38
+ PADDLE_ENFORCE_EQ (x_dims[0 ], c_prev_dims[0 ],
39
+ " Batch size of inputs and states must be equal" );
40
+ PADDLE_ENFORCE_EQ (x_dims[1 ], c_prev_dims[1 ] * 4 ,
41
+ " Dimension of FC should equal to prev state * 4" );
42
+ }
41
43
42
44
int b_size = c_prev_dims[0 ]; // batch size
43
45
int s_dim = c_prev_dims[1 ]; // state dim
You can’t perform that action at this time.
0 commit comments