Skip to content

Commit 9ab738a

Browse files
authored
cherry-pick fix shape check in density_prior_box, test=release/1.6 (#21474)
1 parent 893ea7e commit 9ab738a

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

paddle/fluid/operators/detection/density_prior_box_op.cc

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,23 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
2929
PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
3030
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
3131

32-
PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
33-
"The height of input must smaller than image.");
34-
35-
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
36-
"The width of input must smaller than image.");
32+
if (ctx->IsRuntime()) {
33+
PADDLE_ENFORCE_LT(
34+
input_dims[2], image_dims[2],
35+
platform::errors::InvalidArgument(
36+
"The input tensor Input's height"
37+
"of DensityPriorBoxOp should be smaller than input tensor Image's"
38+
"hight. But received Input's height = %d, Image's height = %d",
39+
input_dims[2], image_dims[2]));
40+
41+
PADDLE_ENFORCE_LT(
42+
input_dims[3], image_dims[3],
43+
platform::errors::InvalidArgument(
44+
"The input tensor Input's width"
45+
"of DensityPriorBoxOp should be smaller than input tensor Image's"
46+
"width. But received Input's width = %d, Image's width = %d",
47+
input_dims[3], image_dims[3]));
48+
}
3749
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
3850

3951
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
@@ -55,10 +67,13 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
5567
dim_vec[3] = 4;
5668
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
5769
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
58-
} else {
70+
} else if (ctx->IsRuntime()) {
5971
int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
6072
ctx->SetOutputDim("Boxes", {dim0, 4});
6173
ctx->SetOutputDim("Variances", {dim0, 4});
74+
} else {
75+
ctx->SetOutputDim("Boxes", {-1, 4});
76+
ctx->SetOutputDim("Variances", {-1, 4});
6277
}
6378
}
6479

0 commit comments

Comments
 (0)