@@ -28,19 +28,20 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
28
28
PADDLE_THROW (" Not support type %s." , type);
29
29
}
30
30
31
- template <typename T>
31
+ template <typename DeviceContext, typename T>
32
32
class BoxCoderKernel : public framework ::OpKernel<T> {
33
33
public:
34
- void EncodeCenterSize (const framework::Tensor& target_box,
35
- const framework::Tensor& prior_box,
36
- const framework::Tensor& prior_box_var,
34
+ void EncodeCenterSize (const framework::Tensor* target_box,
35
+ const framework::Tensor* prior_box,
36
+ const framework::Tensor* prior_box_var,
37
37
const bool normalized, T* output) const {
38
- int64_t row = target_box.dims ()[0 ];
39
- int64_t col = prior_box.dims ()[0 ];
40
- int64_t len = prior_box.dims ()[1 ];
41
- auto * target_box_data = target_box.data <T>();
42
- auto * prior_box_data = prior_box.data <T>();
43
- auto * prior_box_var_data = prior_box_var.data <T>();
38
+ int64_t row = target_box->dims ()[0 ];
39
+ int64_t col = prior_box->dims ()[0 ];
40
+ int64_t len = prior_box->dims ()[1 ];
41
+ auto * target_box_data = target_box->data <T>();
42
+ auto * prior_box_data = prior_box->data <T>();
43
+ const T* prior_box_var_data = nullptr ;
44
+ if (prior_box_var) prior_box_var_data = prior_box_var->data <T>();
44
45
45
46
for (int64_t i = 0 ; i < row; ++i) {
46
47
for (int64_t j = 0 ; j < col; ++j) {
@@ -65,30 +66,35 @@ class BoxCoderKernel : public framework::OpKernel<T> {
65
66
(normalized == false );
66
67
67
68
size_t offset = i * col * len + j * len;
68
- output[offset] = (target_box_center_x - prior_box_center_x) /
69
- prior_box_width / prior_box_var_data[j * len] ;
70
- output[offset + 1 ] = (target_box_center_y - prior_box_center_y) /
71
- prior_box_height / prior_box_var_data[j * len + 1 ] ;
69
+ output[offset] =
70
+ (target_box_center_x - prior_box_center_x) / prior_box_width ;
71
+ output[offset + 1 ] =
72
+ (target_box_center_y - prior_box_center_y) / prior_box_height ;
72
73
output[offset + 2 ] =
73
- std::log (std::fabs (target_box_width / prior_box_width)) /
74
- prior_box_var_data[j * len + 2 ];
74
+ std::log (std::fabs (target_box_width / prior_box_width));
75
75
output[offset + 3 ] =
76
- std::log (std::fabs (target_box_height / prior_box_height)) /
77
- prior_box_var_data[j * len + 3 ];
76
+ std::log (std::fabs (target_box_height / prior_box_height));
77
+ if (prior_box_var) {
78
+ output[offset] /= prior_box_var_data[j * len];
79
+ output[offset + 1 ] /= prior_box_var_data[j * len + 1 ];
80
+ output[offset + 2 ] /= prior_box_var_data[j * len + 2 ];
81
+ output[offset + 3 ] /= prior_box_var_data[j * len + 3 ];
82
+ }
78
83
}
79
84
}
80
85
}
81
- void DecodeCenterSize (const framework::Tensor& target_box,
82
- const framework::Tensor& prior_box,
83
- const framework::Tensor& prior_box_var,
86
+ void DecodeCenterSize (const framework::Tensor* target_box,
87
+ const framework::Tensor* prior_box,
88
+ const framework::Tensor* prior_box_var,
84
89
const bool normalized, T* output) const {
85
- int64_t row = target_box. dims ()[0 ];
86
- int64_t col = prior_box. dims ()[0 ];
87
- int64_t len = prior_box. dims ()[1 ];
90
+ int64_t row = target_box-> dims ()[0 ];
91
+ int64_t col = prior_box-> dims ()[0 ];
92
+ int64_t len = prior_box-> dims ()[1 ];
88
93
89
- auto * target_box_data = target_box.data <T>();
90
- auto * prior_box_data = prior_box.data <T>();
91
- auto * prior_box_var_data = prior_box_var.data <T>();
94
+ auto * target_box_data = target_box->data <T>();
95
+ auto * prior_box_data = prior_box->data <T>();
96
+ const T* prior_box_var_data = nullptr ;
97
+ if (prior_box_var) prior_box_var_data = prior_box_var->data <T>();
92
98
93
99
for (int64_t i = 0 ; i < row; ++i) {
94
100
for (int64_t j = 0 ; j < col; ++j) {
@@ -103,19 +109,32 @@ class BoxCoderKernel : public framework::OpKernel<T> {
103
109
T prior_box_center_y =
104
110
(prior_box_data[j * len + 3 ] + prior_box_data[j * len + 1 ]) / 2 ;
105
111
106
- T target_box_center_x = prior_box_var_data[j * len] *
112
+ T target_box_center_x = 0 , target_box_center_y = 0 ;
113
+ T target_box_width = 0 , target_box_height = 0 ;
114
+ if (prior_box_var) {
115
+ target_box_center_x = prior_box_var_data[j * len] *
107
116
target_box_data[offset] * prior_box_width +
108
117
prior_box_center_x;
109
- T target_box_center_y = prior_box_var_data[j * len + 1 ] *
118
+ target_box_center_y = prior_box_var_data[j * len + 1 ] *
110
119
target_box_data[offset + 1 ] *
111
120
prior_box_height +
112
121
prior_box_center_y;
113
- T target_box_width = std::exp (prior_box_var_data[j * len + 2 ] *
122
+ target_box_width = std::exp (prior_box_var_data[j * len + 2 ] *
114
123
target_box_data[offset + 2 ]) *
115
124
prior_box_width;
116
- T target_box_height = std::exp (prior_box_var_data[j * len + 3 ] *
125
+ target_box_height = std::exp (prior_box_var_data[j * len + 3 ] *
117
126
target_box_data[offset + 3 ]) *
118
127
prior_box_height;
128
+ } else {
129
+ target_box_center_x =
130
+ target_box_data[offset] * prior_box_width + prior_box_center_x;
131
+ target_box_center_y = target_box_data[offset + 1 ] * prior_box_height +
132
+ prior_box_center_y;
133
+ target_box_width =
134
+ std::exp (target_box_data[offset + 2 ]) * prior_box_width;
135
+ target_box_height =
136
+ std::exp (target_box_data[offset + 3 ]) * prior_box_height;
137
+ }
119
138
120
139
output[offset] = target_box_center_x - target_box_width / 2 ;
121
140
output[offset + 1 ] = target_box_center_y - target_box_height / 2 ;
@@ -147,10 +166,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
147
166
bool normalized = context.Attr <bool >(" box_normalized" );
148
167
T* output = output_box->data <T>();
149
168
if (code_type == BoxCodeType::kEncodeCenterSize ) {
150
- EncodeCenterSize (* target_box, * prior_box, * prior_box_var, normalized,
169
+ EncodeCenterSize (target_box, prior_box, prior_box_var, normalized,
151
170
output);
152
171
} else if (code_type == BoxCodeType::kDecodeCenterSize ) {
153
- DecodeCenterSize (* target_box, * prior_box, * prior_box_var, normalized,
172
+ DecodeCenterSize (target_box, prior_box, prior_box_var, normalized,
154
173
output);
155
174
}
156
175
}
0 commit comments