Skip to content

Commit d3e99ae

Browse files
Noplzqingqing01
authored andcommitted
add normalize switch to box_coder_op (#11129)
1 parent e0a8c58 commit d3e99ae

File tree

4 files changed

+78
-44
lines changed

4 files changed

+78
-44
lines changed

paddle/fluid/operators/detection/box_coder_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
9191
"the code type used with the target box")
9292
.SetDefault("encode_center_size")
9393
.InEnum({"encode_center_size", "decode_center_size"});
94+
AddAttr<bool>("box_normalized",
95+
"(bool, default true) "
96+
"whether treat the priorbox as a noramlized box")
97+
.SetDefault(true);
9498
AddOutput("OutputBox",
9599
"(LoDTensor or Tensor) "
96100
"When code_type is 'encode_center_size', the output tensor of "

paddle/fluid/operators/detection/box_coder_op.cu

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
2020
const T* prior_box_var_data,
2121
const T* target_box_data, const int row,
2222
const int col, const int len,
23-
T* output) {
23+
const bool normalized, T* output) {
2424
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
2525
if (idx < row * col) {
2626
const int row_idx = idx / col;
2727
const int col_idx = idx % col;
28-
T prior_box_width =
29-
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
30-
T prior_box_height =
31-
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
28+
T prior_box_width = prior_box_data[col_idx * len + 2] -
29+
prior_box_data[col_idx * len] + (normalized == false);
30+
T prior_box_height = prior_box_data[col_idx * len + 3] -
31+
prior_box_data[col_idx * len + 1] +
32+
(normalized == false);
3233
T prior_box_center_x =
3334
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
3435
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
@@ -41,10 +42,11 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
4142
T target_box_center_y = (target_box_data[row_idx * len + 3] +
4243
target_box_data[row_idx * len + 1]) /
4344
2;
44-
T target_box_width =
45-
target_box_data[row_idx * len + 2] - target_box_data[row_idx * len];
46-
T target_box_height =
47-
target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1];
45+
T target_box_width = target_box_data[row_idx * len + 2] -
46+
target_box_data[row_idx * len] + (normalized == false);
47+
T target_box_height = target_box_data[row_idx * len + 3] -
48+
target_box_data[row_idx * len + 1] +
49+
(normalized == false);
4850

4951
output[idx * len] = (target_box_center_x - prior_box_center_x) /
5052
prior_box_width / prior_box_var_data[col_idx * len];
@@ -63,14 +65,15 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
6365
const T* prior_box_var_data,
6466
const T* target_box_data, const int row,
6567
const int col, const int len,
66-
T* output) {
68+
const bool normalized, T* output) {
6769
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
6870
if (idx < row * col) {
6971
const int col_idx = idx % col;
70-
T prior_box_width =
71-
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
72-
T prior_box_height =
73-
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
72+
T prior_box_width = prior_box_data[col_idx * len + 2] -
73+
prior_box_data[col_idx * len] + (normalized == false);
74+
T prior_box_height = prior_box_data[col_idx * len + 3] -
75+
prior_box_data[col_idx * len + 1] +
76+
(normalized == false);
7477
T prior_box_center_x =
7578
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
7679
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
@@ -93,8 +96,10 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
9396

9497
output[idx * len] = target_box_center_x - target_box_width / 2;
9598
output[idx * len + 1] = target_box_center_y - target_box_height / 2;
96-
output[idx * len + 2] = target_box_center_x + target_box_width / 2;
97-
output[idx * len + 3] = target_box_center_y + target_box_height / 2;
99+
output[idx * len + 2] =
100+
target_box_center_x + target_box_width / 2 - (normalized == false);
101+
output[idx * len + 3] =
102+
target_box_center_y + target_box_height / 2 - (normalized == false);
98103
}
99104
}
100105

@@ -128,14 +133,15 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
128133
T* output = output_box->data<T>();
129134

130135
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
136+
bool normalized = context.Attr<bool>("box_normalized");
131137
if (code_type == BoxCodeType::kEncodeCenterSize) {
132138
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
133139
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
134-
output);
140+
normalized, output);
135141
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
136142
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
137143
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
138-
output);
144+
normalized, output);
139145
}
140146
}
141147
};

paddle/fluid/operators/detection/box_coder_op.h

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
3434
void EncodeCenterSize(const framework::Tensor& target_box,
3535
const framework::Tensor& prior_box,
3636
const framework::Tensor& prior_box_var,
37-
T* output) const {
37+
const bool normalized, T* output) const {
3838
int64_t row = target_box.dims()[0];
3939
int64_t col = prior_box.dims()[0];
4040
int64_t len = prior_box.dims()[1];
@@ -44,10 +44,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
4444

4545
for (int64_t i = 0; i < row; ++i) {
4646
for (int64_t j = 0; j < col; ++j) {
47-
T prior_box_width =
48-
prior_box_data[j * len + 2] - prior_box_data[j * len];
49-
T prior_box_height =
50-
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
47+
T prior_box_width = prior_box_data[j * len + 2] -
48+
prior_box_data[j * len] + (normalized == false);
49+
T prior_box_height = prior_box_data[j * len + 3] -
50+
prior_box_data[j * len + 1] +
51+
(normalized == false);
5152
T prior_box_center_x =
5253
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
5354
T prior_box_center_y =
@@ -57,10 +58,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
5758
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
5859
T target_box_center_y =
5960
(target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
60-
T target_box_width =
61-
target_box_data[i * len + 2] - target_box_data[i * len];
62-
T target_box_height =
63-
target_box_data[i * len + 3] - target_box_data[i * len + 1];
61+
T target_box_width = target_box_data[i * len + 2] -
62+
target_box_data[i * len] + (normalized == false);
63+
T target_box_height = target_box_data[i * len + 3] -
64+
target_box_data[i * len + 1] +
65+
(normalized == false);
6466

6567
size_t offset = i * col * len + j * len;
6668
output[offset] = (target_box_center_x - prior_box_center_x) /
@@ -79,7 +81,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
7981
void DecodeCenterSize(const framework::Tensor& target_box,
8082
const framework::Tensor& prior_box,
8183
const framework::Tensor& prior_box_var,
82-
T* output) const {
84+
const bool normalized, T* output) const {
8385
int64_t row = target_box.dims()[0];
8486
int64_t col = prior_box.dims()[0];
8587
int64_t len = prior_box.dims()[1];
@@ -91,10 +93,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
9193
for (int64_t i = 0; i < row; ++i) {
9294
for (int64_t j = 0; j < col; ++j) {
9395
size_t offset = i * col * len + j * len;
94-
T prior_box_width =
95-
prior_box_data[j * len + 2] - prior_box_data[j * len];
96-
T prior_box_height =
97-
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
96+
T prior_box_width = prior_box_data[j * len + 2] -
97+
prior_box_data[j * len] + (normalized == false);
98+
T prior_box_height = prior_box_data[j * len + 3] -
99+
prior_box_data[j * len + 1] +
100+
(normalized == false);
98101
T prior_box_center_x =
99102
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
100103
T prior_box_center_y =
@@ -116,8 +119,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
116119

117120
output[offset] = target_box_center_x - target_box_width / 2;
118121
output[offset + 1] = target_box_center_y - target_box_height / 2;
119-
output[offset + 2] = target_box_center_x + target_box_width / 2;
120-
output[offset + 3] = target_box_center_y + target_box_height / 2;
122+
output[offset + 2] =
123+
target_box_center_x + target_box_width / 2 - (normalized == false);
124+
output[offset + 3] =
125+
target_box_center_y + target_box_height / 2 - (normalized == false);
121126
}
122127
}
123128
}
@@ -139,11 +144,14 @@ class BoxCoderKernel : public framework::OpKernel<T> {
139144
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
140145

141146
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
147+
bool normalized = context.Attr<bool>("box_normalized");
142148
T* output = output_box->data<T>();
143149
if (code_type == BoxCodeType::kEncodeCenterSize) {
144-
EncodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
150+
EncodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
151+
output);
145152
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
146-
DecodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
153+
DecodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
154+
output);
147155
}
148156
}
149157
};

python/paddle/fluid/tests/unittests/test_box_coder_op.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from op_test import OpTest
2020

2121

22-
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
22+
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
23+
box_normalized):
2324
prior_box_x = (
2425
(prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0])
2526
prior_box_y = (
@@ -30,6 +31,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
3031
(prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0])
3132
prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
3233
prior_box_var.shape[1])
34+
if not box_normalized:
35+
prior_box_height = prior_box_height + 1
36+
prior_box_width = prior_box_width + 1
3337

3438
if (code_type == "EncodeCenterSize"):
3539
target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape(
@@ -40,6 +44,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
4044
target_box.shape[0], 1)
4145
target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape(
4246
target_box.shape[0], 1)
47+
if not box_normalized:
48+
target_box_height = target_box_height + 1
49+
target_box_width = target_box_width + 1
4350

4451
output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \
4552
prior_box_var[:,:,0]
@@ -64,21 +71,25 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
6471
output_box[:, :, 1] = target_box_y - target_box_height / 2
6572
output_box[:, :, 2] = target_box_x + target_box_width / 2
6673
output_box[:, :, 3] = target_box_y + target_box_height / 2
74+
if not box_normalized:
75+
output_box[:, :, 2] = output_box[:, :, 2] - 1
76+
output_box[:, :, 3] = output_box[:, :, 3] - 1
6777

6878

69-
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
79+
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type,
80+
box_normalized):
7081
n = target_box.shape[0]
7182
m = prior_box.shape[0]
7283
output_box = np.zeros((n, m, 4), dtype=np.float32)
7384
for i in range(len(lod) - 1):
7485
if (code_type == "EncodeCenterSize"):
7586
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box,
7687
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
77-
code_type)
88+
code_type, box_normalized)
7889
elif (code_type == "DecodeCenterSize"):
7990
box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box,
8091
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
81-
code_type)
92+
code_type, box_normalized)
8293
return output_box
8394

8495

@@ -93,15 +104,19 @@ def setUp(self):
93104
prior_box_var = np.random.random((10, 4)).astype('float32')
94105
target_box = np.random.random((5, 10, 4)).astype('float32')
95106
code_type = "DecodeCenterSize"
107+
box_normalized = False
96108
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
97-
lod[0], code_type)
109+
lod[0], code_type, box_normalized)
98110

99111
self.inputs = {
100112
'PriorBox': prior_box,
101113
'PriorBoxVar': prior_box_var,
102114
'TargetBox': target_box,
103115
}
104-
self.attrs = {'code_type': 'decode_center_size'}
116+
self.attrs = {
117+
'code_type': 'decode_center_size',
118+
'box_normalized': False
119+
}
105120
self.outputs = {'OutputBox': output_box}
106121

107122

@@ -116,15 +131,16 @@ def setUp(self):
116131
prior_box_var = np.random.random((10, 4)).astype('float32')
117132
target_box = np.random.random((20, 4)).astype('float32')
118133
code_type = "EncodeCenterSize"
134+
box_normalized = True
119135
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
120-
lod[0], code_type)
136+
lod[0], code_type, box_normalized)
121137

122138
self.inputs = {
123139
'PriorBox': prior_box,
124140
'PriorBoxVar': prior_box_var,
125141
'TargetBox': (target_box, lod),
126142
}
127-
self.attrs = {'code_type': 'encode_center_size'}
143+
self.attrs = {'code_type': 'encode_center_size', 'box_normalized': True}
128144
self.outputs = {'OutputBox': output_box}
129145

130146

0 commit comments

Comments
 (0)