Skip to content

Commit 5c10c57

Browse files
luotao1phlrain
authored andcommitted
Merge pull request #16854 from luotao1/conv_shift_infershape
Fix conv_shift_op infershape
1 parent a8afb98 commit 5c10c57

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

paddle/fluid/operators/conv_shift_op.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
3636
auto y_dims = ctx->GetInputDim("Y");
3737
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
3838
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
39-
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
40-
"The 1st dimension of Input(X) and Input(Y) should "
41-
"be equal.");
42-
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
43-
"The 2nd dimension of Input(Y) should be odd.");
44-
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
45-
"The 2nd dimension of Input(Y) should be less than or "
46-
"equal to the 2nd dimension of Input(X).");
39+
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
40+
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
41+
"The 1st dimension of Input(X) and Input(Y) should "
42+
"be equal.");
43+
if (ctx->IsRuntime() || y_dims[1] > 0)
44+
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
45+
"The 2nd dimension of Input(Y) should be odd.");
46+
if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
47+
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
48+
"The 2nd dimension of Input(Y) should be less than or "
49+
"equal to the 2nd dimension of Input(X).");
4750
ctx->ShareDim("X", /*->*/ "Out");
4851
ctx->ShareLoD("X", /*->*/ "Out");
4952
}

0 commit comments

Comments
 (0)