@@ -143,22 +143,18 @@ class UnfoldOp : public framework::OperatorWithKernel {
143143 " but recieved dilations_height: %d dilations_width: %d." ,
144144 dilations[0 ], dilations[1 ]));
145145
146- bool contain_unknown_dim = framework::contain_unknown_dim (in_dims);
147- bool check = ctx->IsRuntime () || !contain_unknown_dim;
148- if (check) {
149- std::vector<int > out_dims;
150- out_dims.push_back (in_dims[0 ]);
151-
152- int output_channels = in_dims[1 ] * kernel_sizes[0 ] * kernel_sizes[1 ];
153- out_dims.push_back (output_channels);
154-
155- int output_height =
156- CalcOutputSize (in_dims[2 ], kernel_sizes[0 ], dilations[0 ], paddings[0 ],
157- paddings[2 ], strides[0 ]);
158- int output_width =
159- CalcOutputSize (in_dims[3 ], kernel_sizes[1 ], dilations[1 ], paddings[1 ],
160- paddings[3 ], strides[1 ]);
161- // check output height and width
146+ std::vector<int > out_dims;
147+ out_dims.push_back (in_dims[0 ]);
148+ int output_channels = in_dims[1 ] * kernel_sizes[0 ] * kernel_sizes[1 ];
149+ out_dims.push_back (output_channels);
150+
151+ int output_height =
152+ CalcOutputSize (in_dims[2 ], kernel_sizes[0 ], dilations[0 ], paddings[0 ],
153+ paddings[2 ], strides[0 ]);
154+ int output_width = CalcOutputSize (in_dims[3 ], kernel_sizes[1 ], dilations[1 ],
155+ paddings[1 ], paddings[3 ], strides[1 ]);
156+ if (ctx->IsRuntime ()) {
157+ // only check output height and width in runtime
162158 PADDLE_ENFORCE_GT (
163159 output_height, 0 ,
164160 platform::errors::InvalidArgument (
@@ -179,11 +175,10 @@ class UnfoldOp : public framework::OperatorWithKernel {
179175 in_dims[2 ], in_dims[3 ], kernel_sizes[0 ], kernel_sizes[1 ],
180176 strides[0 ], strides[1 ], dilations[0 ], dilations[1 ], output_height,
181177 output_width));
182- int output_col_length = output_height * output_width;
183- out_dims.push_back (output_col_length);
184-
185- ctx->SetOutputDim (" Y" , framework::make_ddim (out_dims));
186178 }
179+ int output_col_length = output_height * output_width;
180+ out_dims.push_back (output_col_length);
181+ ctx->SetOutputDim (" Y" , framework::make_ddim (out_dims));
187182 }
188183
189184 protected:
0 commit comments