Skip to content

Commit 36f08ee

Browse files
authored
CUDA kernel for density_prior_box_op. (#14513)
* CUDA kernel for density_prior_box_op. * Support flatten to 2D.
1 parent dfbdece commit 36f08ee

File tree

9 files changed

+305
-117
lines changed

9 files changed

+305
-117
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k
276276
paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
277277
paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
278278
paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False))
279-
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None))
279+
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'flatten_to_2d', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, False, None))
280280
paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False))
281281
paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
282282
paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None))

paddle/fluid/framework/op_desc.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
252252
this->attrs_[name] = std::vector<int>();
253253
break;
254254
}
255+
case proto::AttrType::LONGS: {
256+
VLOG(110) << "SetAttr: " << Type() << ", " << name
257+
<< " from LONGS to LONGS";
258+
this->attrs_[name] = std::vector<int64_t>();
259+
break;
260+
}
255261
case proto::AttrType::FLOATS: {
256262
VLOG(110) << "SetAttr: " << Type() << ", " << name
257263
<< " from INTS to FLOATS";

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ iou_similarity_op.cu)
2222
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
2323
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc)
2424
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
25-
detection_library(density_prior_box_op SRCS density_prior_box_op.cc)
25+
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
2626
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
2727
anchor_generator_op.cu)
2828
detection_library(target_assign_op SRCS target_assign_op.cc

paddle/fluid/operators/detection/density_prior_box_op.cc

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,32 +39,35 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
3939
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
4040
auto fixed_ratios = ctx->Attrs().Get<std::vector<float>>("fixed_ratios");
4141
auto densities = ctx->Attrs().Get<std::vector<int>>("densities");
42+
bool flatten = ctx->Attrs().Get<bool>("flatten_to_2d");
4243

4344
PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(),
4445
"The number of fixed_sizes and densities must be equal.");
4546
size_t num_priors = 0;
46-
if ((fixed_sizes.size() > 0) && (densities.size() > 0)) {
47-
for (size_t i = 0; i < densities.size(); ++i) {
48-
if (fixed_ratios.size() > 0) {
49-
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
50-
}
51-
}
47+
for (size_t i = 0; i < densities.size(); ++i) {
48+
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
49+
}
50+
if (!flatten) {
51+
std::vector<int64_t> dim_vec(4);
52+
dim_vec[0] = input_dims[2];
53+
dim_vec[1] = input_dims[3];
54+
dim_vec[2] = num_priors;
55+
dim_vec[3] = 4;
56+
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
57+
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
58+
} else {
59+
int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
60+
ctx->SetOutputDim("Boxes", {dim0, 4});
61+
ctx->SetOutputDim("Variances", {dim0, 4});
5262
}
53-
std::vector<int64_t> dim_vec(4);
54-
dim_vec[0] = input_dims[2];
55-
dim_vec[1] = input_dims[3];
56-
dim_vec[2] = num_priors;
57-
dim_vec[3] = 4;
58-
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
59-
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
6063
}
6164

6265
protected:
6366
framework::OpKernelType GetExpectedKernelType(
6467
const framework::ExecutionContext& ctx) const override {
6568
return framework::OpKernelType(
6669
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
67-
platform::CPUPlace());
70+
ctx.GetPlace());
6871
}
6972
};
7073

@@ -101,7 +104,10 @@ class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
101104
});
102105
AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
103106
.SetDefault(true);
104-
107+
AddAttr<bool>("flatten_to_2d",
108+
"(bool) Whether to flatten to 2D and "
109+
"the second dim is 4.")
110+
.SetDefault(false);
105111
AddAttr<float>(
106112
"step_w",
107113
"Density prior boxes step across width, 0.0 for auto calculation.")
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/detection/density_prior_box_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
template <typename T>
21+
static __device__ inline T Clip(T in) {
22+
return min(max(in, 0.), 1.);
23+
}
24+
25+
template <typename T>
26+
static __global__ void GenDensityPriorBox(
27+
const int height, const int width, const int im_height, const int im_width,
28+
const T offset, const T step_width, const T step_height,
29+
const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin,
30+
const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) {
31+
int gidx = blockIdx.x * blockDim.x + threadIdx.x;
32+
int gidy = blockIdx.y * blockDim.y + threadIdx.y;
33+
int step_x = blockDim.x * gridDim.x;
34+
int step_y = blockDim.y * gridDim.y;
35+
36+
const T* width_ratio = ratios_shift;
37+
const T* height_ratio = ratios_shift + num_priors;
38+
const T* width_shift = ratios_shift + 2 * num_priors;
39+
const T* height_shift = ratios_shift + 3 * num_priors;
40+
41+
for (int j = gidy; j < height; j += step_y) {
42+
for (int i = gidx; i < width * num_priors; i += step_x) {
43+
int h = j;
44+
int w = i / num_priors;
45+
int k = i % num_priors;
46+
47+
T center_x = (w + offset) * step_width;
48+
T center_y = (h + offset) * step_height;
49+
50+
T center_x_temp = center_x + width_shift[k];
51+
T center_y_temp = center_y + height_shift[k];
52+
53+
T box_width_ratio = width_ratio[k] / 2.;
54+
T box_height_ratio = height_ratio[k] / 2.;
55+
56+
T xmin = max((center_x_temp - box_width_ratio) / im_width, 0.);
57+
T ymin = max((center_y_temp - box_height_ratio) / im_height, 0.);
58+
T xmax = min((center_x_temp + box_width_ratio) / im_width, 1.);
59+
T ymax = min((center_y_temp + box_height_ratio) / im_height, 1.);
60+
61+
int out_offset = (j * width * num_priors + i) * 4;
62+
out[out_offset] = is_clip ? Clip<T>(xmin) : xmin;
63+
out[out_offset + 1] = is_clip ? Clip<T>(ymin) : ymin;
64+
out[out_offset + 2] = is_clip ? Clip<T>(xmax) : xmax;
65+
out[out_offset + 3] = is_clip ? Clip<T>(ymax) : ymax;
66+
67+
var[out_offset] = var_xmin;
68+
var[out_offset + 1] = var_ymin;
69+
var[out_offset + 2] = var_xmax;
70+
var[out_offset + 3] = var_ymax;
71+
}
72+
}
73+
}
74+
75+
template <typename T>
76+
class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> {
77+
public:
78+
void Compute(const framework::ExecutionContext& ctx) const override {
79+
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
80+
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
81+
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
82+
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
83+
84+
auto variances = ctx.Attr<std::vector<float>>("variances");
85+
auto is_clip = ctx.Attr<bool>("clip");
86+
87+
auto fixed_sizes = ctx.Attr<std::vector<float>>("fixed_sizes");
88+
auto fixed_ratios = ctx.Attr<std::vector<float>>("fixed_ratios");
89+
auto densities = ctx.Attr<std::vector<int>>("densities");
90+
91+
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
92+
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
93+
T offset = static_cast<T>(ctx.Attr<float>("offset"));
94+
95+
auto img_width = image->dims()[3];
96+
auto img_height = image->dims()[2];
97+
98+
auto feature_width = input->dims()[3];
99+
auto feature_height = input->dims()[2];
100+
101+
T step_width, step_height;
102+
if (step_w == 0 || step_h == 0) {
103+
step_width = static_cast<T>(img_width) / feature_width;
104+
step_height = static_cast<T>(img_height) / feature_height;
105+
} else {
106+
step_width = step_w;
107+
step_height = step_h;
108+
}
109+
110+
int num_priors = 0;
111+
for (size_t i = 0; i < densities.size(); ++i) {
112+
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
113+
}
114+
int step_average = static_cast<int>((step_width + step_height) * 0.5);
115+
116+
framework::Tensor h_temp;
117+
T* tdata = h_temp.mutable_data<T>({num_priors * 4}, platform::CPUPlace());
118+
int idx = 0;
119+
for (size_t s = 0; s < fixed_sizes.size(); ++s) {
120+
auto fixed_size = fixed_sizes[s];
121+
int density = densities[s];
122+
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
123+
float ar = fixed_ratios[r];
124+
int shift = step_average / density;
125+
float box_width_ratio = fixed_size * sqrt(ar);
126+
float box_height_ratio = fixed_size / sqrt(ar);
127+
for (int di = 0; di < density; ++di) {
128+
for (int dj = 0; dj < density; ++dj) {
129+
float center_x_temp = shift / 2. + dj * shift - step_average / 2.;
130+
float center_y_temp = shift / 2. + di * shift - step_average / 2.;
131+
tdata[idx] = box_width_ratio;
132+
tdata[num_priors + idx] = box_height_ratio;
133+
tdata[2 * num_priors + idx] = center_x_temp;
134+
tdata[3 * num_priors + idx] = center_y_temp;
135+
idx++;
136+
}
137+
}
138+
}
139+
}
140+
141+
boxes->mutable_data<T>(ctx.GetPlace());
142+
vars->mutable_data<T>(ctx.GetPlace());
143+
144+
framework::Tensor d_temp;
145+
framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp);
146+
147+
// At least use 32 threads, at most 512 threads.
148+
// blockx is multiple of 32.
149+
int blockx = std::min(((feature_width * num_priors + 31) >> 5) << 5, 512L);
150+
int gridx = (feature_width * num_priors + blockx - 1) / blockx;
151+
dim3 threads(blockx, 1);
152+
dim3 grids(gridx, feature_height);
153+
154+
auto stream =
155+
ctx.template device_context<platform::CUDADeviceContext>().stream();
156+
GenDensityPriorBox<T><<<grids, threads, 0, stream>>>(
157+
feature_height, feature_width, img_height, img_width, offset,
158+
step_width, step_height, num_priors, d_temp.data<T>(), is_clip,
159+
variances[0], variances[1], variances[2], variances[3],
160+
boxes->data<T>(), vars->data<T>());
161+
}
162+
}; // namespace operators
163+
164+
} // namespace operators
165+
} // namespace paddle
166+
167+
namespace ops = paddle::operators;
168+
REGISTER_OP_CUDA_KERNEL(density_prior_box,
169+
ops::DensityPriorBoxOpCUDAKernel<float>,
170+
ops::DensityPriorBoxOpCUDAKernel<double>);

paddle/fluid/operators/detection/density_prior_box_op.h

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
licensed under the Apache License, Version 2.0 (the "License");
33
you may not use this file except in compliance with the License.
44
You may obtain a copy of the License at
@@ -52,18 +52,16 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
5252
step_height = step_h;
5353
}
5454
int num_priors = 0;
55-
if (fixed_sizes.size() > 0 && densities.size() > 0) {
56-
for (size_t i = 0; i < densities.size(); ++i) {
57-
if (fixed_ratios.size() > 0) {
58-
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
59-
}
60-
}
55+
for (size_t i = 0; i < densities.size(); ++i) {
56+
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
6157
}
6258

6359
boxes->mutable_data<T>(ctx.GetPlace());
6460
vars->mutable_data<T>(ctx.GetPlace());
65-
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
6661

62+
auto box_dim = vars->dims();
63+
boxes->Resize({feature_height, feature_width, num_priors, 4});
64+
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
6765
int step_average = static_cast<int>((step_width + step_height) * 0.5);
6866

6967
for (int h = 0; h < feature_height; ++h) {
@@ -76,36 +74,34 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
7674
auto fixed_size = fixed_sizes[s];
7775
int density = densities[s];
7876
// Generate density prior boxes with fixed ratios.
79-
if (fixed_ratios.size() > 0) {
80-
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
81-
float ar = fixed_ratios[r];
82-
int shift = step_average / density;
83-
float box_width_ratio = fixed_size * sqrt(ar);
84-
float box_height_ratio = fixed_size / sqrt(ar);
85-
for (int di = 0; di < density; ++di) {
86-
for (int dj = 0; dj < density; ++dj) {
87-
float center_x_temp =
88-
center_x - step_average / 2. + shift / 2. + dj * shift;
89-
float center_y_temp =
90-
center_y - step_average / 2. + shift / 2. + di * shift;
91-
e_boxes(h, w, idx, 0) =
92-
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
93-
? (center_x_temp - box_width_ratio / 2.) / img_width
94-
: 0;
95-
e_boxes(h, w, idx, 1) =
96-
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
97-
? (center_y_temp - box_height_ratio / 2.) / img_height
98-
: 0;
99-
e_boxes(h, w, idx, 2) =
100-
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
101-
? (center_x_temp + box_width_ratio / 2.) / img_width
102-
: 1;
103-
e_boxes(h, w, idx, 3) =
104-
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
105-
? (center_y_temp + box_height_ratio / 2.) / img_height
106-
: 1;
107-
idx++;
108-
}
77+
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
78+
float ar = fixed_ratios[r];
79+
int shift = step_average / density;
80+
float box_width_ratio = fixed_size * sqrt(ar);
81+
float box_height_ratio = fixed_size / sqrt(ar);
82+
for (int di = 0; di < density; ++di) {
83+
for (int dj = 0; dj < density; ++dj) {
84+
float center_x_temp =
85+
center_x - step_average / 2. + shift / 2. + dj * shift;
86+
float center_y_temp =
87+
center_y - step_average / 2. + shift / 2. + di * shift;
88+
e_boxes(h, w, idx, 0) =
89+
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
90+
? (center_x_temp - box_width_ratio / 2.) / img_width
91+
: 0;
92+
e_boxes(h, w, idx, 1) =
93+
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
94+
? (center_y_temp - box_height_ratio / 2.) / img_height
95+
: 0;
96+
e_boxes(h, w, idx, 2) =
97+
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
98+
? (center_x_temp + box_width_ratio / 2.) / img_width
99+
: 1;
100+
e_boxes(h, w, idx, 3) =
101+
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
102+
? (center_y_temp + box_height_ratio / 2.) / img_height
103+
: 1;
104+
idx++;
109105
}
110106
}
111107
}
@@ -139,6 +135,7 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
139135
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
140136

141137
vars->Resize(var_dim);
138+
boxes->Resize(box_dim);
142139
}
143140
}; // namespace operators
144141

0 commit comments

Comments
 (0)