@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
36
36
auto y_dims = ctx->GetInputDim (" Y" );
37
37
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
38
38
PADDLE_ENFORCE_EQ (y_dims.size (), 2 , " Input(Y)'s rank should be 2." );
39
- PADDLE_ENFORCE_EQ (x_dims[0 ], y_dims[0 ],
40
- " The 1st dimension of Input(X) and Input(Y) should "
41
- " be equal." );
42
- PADDLE_ENFORCE_EQ (y_dims[1 ] % 2 , 1 ,
43
- " The 2nd dimension of Input(Y) should be odd." );
44
- PADDLE_ENFORCE_LE (y_dims[1 ], x_dims[1 ],
45
- " The 2nd dimension of Input(Y) should be less than or "
46
- " equal to the 2nd dimension of Input(X)." );
39
+ if (ctx->IsRuntime () || (x_dims[0 ] > 0 && y_dims[0 ] > 0 ))
40
+ PADDLE_ENFORCE_EQ (x_dims[0 ], y_dims[0 ],
41
+ " The 1st dimension of Input(X) and Input(Y) should "
42
+ " be equal." );
43
+ if (ctx->IsRuntime () || y_dims[1 ] > 0 )
44
+ PADDLE_ENFORCE_EQ (y_dims[1 ] % 2 , 1 ,
45
+ " The 2nd dimension of Input(Y) should be odd." );
46
+ if (ctx->IsRuntime () || (x_dims[1 ] > 0 && y_dims[1 ] > 0 ))
47
+ PADDLE_ENFORCE_LE (y_dims[1 ], x_dims[1 ],
48
+ " The 2nd dimension of Input(Y) should be less than or "
49
+ " equal to the 2nd dimension of Input(X)." );
47
50
ctx->ShareDim (" X" , /* ->*/ " Out" );
48
51
ctx->ShareLoD (" X" , /* ->*/ " Out" );
49
52
}
0 commit comments