Skip to content

Commit 1743d1a

Browse files
authored
Merge pull request #15356 from jerrywgz/add_clip_op
Add box clip op
2 parents 43a67a2 + 4f18a9b commit 1743d1a

File tree

9 files changed

+368
-0
lines changed

9 files changed

+368
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
325325
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
326326
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
327327
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
328+
paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,))
328329
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
329330
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
330331
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
3131
polygon_box_transform_op.cu)
3232
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
3333
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
34+
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
3435
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
3536

3637
if(WITH_GPU)

paddle/fluid/operators/detection/bbox_util.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
9999
}
100100
}
101101

102+
template <class T>
103+
void ClipTiledBoxes(const platform::DeviceContext& ctx,
104+
const framework::Tensor& im_info,
105+
const framework::Tensor& input_boxes,
106+
framework::Tensor* out) {
107+
T* out_data = out->mutable_data<T>(ctx.GetPlace());
108+
const T* im_info_data = im_info.data<T>();
109+
const T* input_boxes_data = input_boxes.data<T>();
110+
T zero(0);
111+
T im_w = round(im_info_data[1] / im_info_data[2]);
112+
T im_h = round(im_info_data[0] / im_info_data[2]);
113+
for (int64_t i = 0; i < input_boxes.numel(); ++i) {
114+
if (i % 4 == 0) {
115+
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
116+
} else if (i % 4 == 1) {
117+
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
118+
} else if (i % 4 == 2) {
119+
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
120+
} else {
121+
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
122+
}
123+
}
124+
}
125+
102126
} // namespace operators
103127
} // namespace paddle
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/detection/box_clip_op.h"
13+
#include "paddle/fluid/framework/op_registry.h"
14+
15+
namespace paddle {
16+
namespace operators {
17+
18+
class BoxClipOp : public framework::OperatorWithKernel {
19+
public:
20+
using framework::OperatorWithKernel::OperatorWithKernel;
21+
22+
protected:
23+
void InferShape(framework::InferShapeContext* ctx) const override {
24+
PADDLE_ENFORCE(ctx->HasInput("Input"),
25+
"Input(Input) of BoxClipOp should not be null.");
26+
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
27+
"Input(ImInfo) of BoxClipOp should not be null.");
28+
29+
auto input_box_dims = ctx->GetInputDim("Input");
30+
auto im_info_dims = ctx->GetInputDim("ImInfo");
31+
32+
if (ctx->IsRuntime()) {
33+
auto input_box_size = input_box_dims.size();
34+
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
35+
"The last dimension of Input must be 4");
36+
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
37+
"The rank of Input(Input) in BoxClipOp must be 2");
38+
PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
39+
"The last dimension of ImInfo must be 3");
40+
}
41+
ctx->ShareDim("Input", /*->*/ "Output");
42+
ctx->ShareLoD("Input", /*->*/ "Output");
43+
}
44+
};
45+
46+
// Declares the inputs, output and user-facing documentation of box_clip.
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so each fragment
    // that continues a sentence must end with an explicit space.
    AddInput("Input",
             "(LoDTensor) "
             "Input is a LoDTensor with shape [..., 4] holds 4 points "
             "in last dimension in format [xmin, ymin, xmax, ymax]");
    AddInput("ImInfo",
             "(Tensor) Information for image reshape is in shape (N, 3), "
             "in format (height, width, im_scale)");
    AddOutput("Output",
              "(LoDTensor) "
              "Output is a LoDTensor with the same shape as Input "
              "and it is the result after clip");
    AddComment(R"DOC(
This operator clips input boxes to original input images.

For each input box, The formula is given as follows:

$$xmin = \max(\min(xmin, im_w - 1), 0)$$
$$ymin = \max(\min(ymin, im_h - 1), 0)$$
$$xmax = \max(\min(xmax, im_w - 1), 0)$$
$$ymax = \max(\min(ymax, im_h - 1), 0)$$

where im_w and im_h are computed from ImInfo, the formula is given as follows:

$$im_w = \round(width / im_scale)$$
$$im_h = \round(height / im_scale)$$
)DOC");
  }
};
77+
78+
} // namespace operators
79+
} // namespace paddle
80+
81+
// Registers the box_clip operator with its definition (BoxClipOp) and its
// proto maker (BoxClipOpMaker). EmptyGradOpMaker is passed as the grad maker,
// which presumably means no gradient op is generated -- confirm against the
// framework. The CPU kernel is registered for float and double.
namespace ops = paddle::operators;
82+
REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
83+
paddle::framework::EmptyGradOpMaker);
84+
REGISTER_OP_CPU_KERNEL(
85+
box_clip, ops::BoxClipKernel<paddle::platform::CPUDeviceContext, float>,
86+
ops::BoxClipKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License. */
13+
14+
#include <algorithm>
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/operators/detection/box_clip_op.h"
17+
#include "paddle/fluid/operators/math/math_function.h"
18+
#include "paddle/fluid/platform/cuda_primitives.h"
19+
#include "paddle/fluid/platform/hostdevice.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using Tensor = framework::Tensor;
25+
// Alias was misspelled as `LoDTenso`; the kernel below refers to `LoDTensor`.
// Re-declaring the alias with the same target as box_clip_op.h is legal.
using LoDTensor = framework::LoDTensor;
26+
27+
static constexpr int ImInfoSize = 3;
28+
29+
template <typename T, int BlockSize>
30+
static __global__ void GPUBoxClip(const T *input, const size_t *lod,
31+
const size_t width, const T *im_info,
32+
T *output) {
33+
T im_w = round(im_info[blockIdx.x * ImInfoSize + 1] /
34+
im_info[blockIdx.x * ImInfoSize + 2]);
35+
T im_h = round(im_info[blockIdx.x * ImInfoSize] /
36+
im_info[blockIdx.x * ImInfoSize + 2]);
37+
for (int i = threadIdx.x; i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * width;
38+
i += BlockSize) {
39+
int idx = lod[blockIdx.x] * width + i;
40+
T im_size = (idx % 2 == 0) ? im_w : im_h;
41+
output[idx] = max(min(input[idx], im_size - 1), T(0.));
42+
}
43+
}
44+
45+
template <typename DeviceContext, typename T>
46+
class GPUBoxClipKernel : public framework::OpKernel<T> {
47+
public:
48+
void Compute(const framework::ExecutionContext &context) const override {
49+
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
50+
"This kernel only runs on GPU device.");
51+
auto *input = context.Input<LoDTensor>("Input");
52+
auto *im_info = context.Input<Tensor>("ImInfo");
53+
auto *output = context.Output<LoDTensor>("Output");
54+
const int64_t num = input->dims()[0];
55+
const int64_t bbox_width = input->numel() / num;
56+
auto lod = input->lod();
57+
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
58+
auto &dev_ctx = context.template device_context<DeviceContext>();
59+
auto stream = dev_ctx.stream();
60+
const size_t batch_size = lod.back().size() - 1;
61+
T *output_data = output->mutable_data<T>(dev_ctx.GetPlace());
62+
GPUBoxClip<T, 512><<<batch_size, 512, 0, stream>>>(
63+
input->data<T>(), abs_offset_lod[0].CUDAMutableData(dev_ctx.GetPlace()),
64+
bbox_width, im_info->data<T>(), output_data);
65+
}
66+
};
67+
68+
} // namespace operators
69+
} // namespace paddle
70+
71+
// Registers the CUDA kernel of box_clip for float and double.
namespace ops = paddle::operators;
72+
REGISTER_OP_CUDA_KERNEL(
73+
box_clip, ops::GPUBoxClipKernel<paddle::platform::CUDADeviceContext, float>,
74+
ops::GPUBoxClipKernel<paddle::platform::CUDADeviceContext, double>);
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#pragma once
13+
#include <string>
14+
#include "paddle/fluid/framework/op_registry.h"
15+
#include "paddle/fluid/operators/detection/bbox_util.h"
16+
#include "paddle/fluid/operators/math/math_function.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using Tensor = framework::Tensor;
22+
using LoDTensor = framework::LoDTensor;
23+
24+
template <typename DeviceContext, typename T>
25+
class BoxClipKernel : public framework::OpKernel<T> {
26+
public:
27+
void Compute(const framework::ExecutionContext& context) const override {
28+
auto* input_box = context.Input<LoDTensor>("Input");
29+
auto* im_info = context.Input<LoDTensor>("ImInfo");
30+
auto* output_box = context.Output<LoDTensor>("Output");
31+
auto& dev_ctx =
32+
context.template device_context<platform::CPUDeviceContext>();
33+
output_box->mutable_data<T>(context.GetPlace());
34+
if (input_box->lod().size()) {
35+
PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
36+
"Only support 1 level of LoD.");
37+
}
38+
auto box_lod = input_box->lod().back();
39+
int64_t n = static_cast<int64_t>(box_lod.size() - 1);
40+
for (int i = 0; i < n; ++i) {
41+
Tensor im_info_slice = im_info->Slice(i, i + 1);
42+
Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]);
43+
Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]);
44+
ClipTiledBoxes<T>(dev_ctx, im_info_slice, box_slice, &output_slice);
45+
}
46+
}
47+
};
48+
49+
} // namespace operators
50+
} // namespace paddle

python/paddle/fluid/layers/detection.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
'box_coder',
5050
'polygon_box_transform',
5151
'yolov3_loss',
52+
'box_clip',
5253
'multiclass_nms',
5354
]
5455

@@ -2055,6 +2056,54 @@ def generate_proposals(scores,
20552056
return rpn_rois, rpn_roi_probs
20562057

20572058

2059+
def box_clip(input, im_info, name=None):
    """
    Clip the box into the size given by im_info
    For each input box, the formula is given as follows:

    .. code-block:: text

        xmin = max(min(xmin, im_w - 1), 0)
        ymin = max(min(ymin, im_h - 1), 0)
        xmax = max(min(xmax, im_w - 1), 0)
        ymax = max(min(ymax, im_h - 1), 0)

    where im_w and im_h are computed from im_info:

    .. code-block:: text

        im_h = round(height / scale)
        im_w = round(width / scale)

    Args:
        input(variable): The input box, the last dimension is 4.
        im_info(variable): The information of image with shape [N, 3] with
                            layout (height, width, scale). height and width
                            is the input size and scale is the ratio of input
                            size and original size.
        name (str): The name of this layer. It is optional.

    Returns:
        Variable: The clipped tensor variable.

    Examples:
        .. code-block:: python

            boxes = fluid.layers.data(
                name='data', shape=[8, 4], dtype='float32', lod_level=1)
            im_info = fluid.layers.data(name='im_info', shape=[3])
            out = fluid.layers.box_clip(
                input=boxes, im_info=im_info)
    """

    # locals() must be captured before any other local is bound so that only
    # the call arguments are forwarded to LayerHelper.
    helper = LayerHelper("box_clip", **locals())
    output = helper.create_variable_for_type_inference(dtype=input.dtype)
    inputs = {"Input": input, "ImInfo": im_info}
    helper.append_op(type="box_clip", inputs=inputs, outputs={"Output": output})

    return output
2105+
2106+
20582107
def multiclass_nms(bboxes,
20592108
scores,
20602109
score_threshold,
@@ -2132,9 +2181,11 @@ class number
21322181
(After version 1.3, when no boxes detected, the lod is changed
21332182
from {0} to {1})
21342183
2184+
21352185
Examples:
21362186
.. code-block:: python
21372187
2188+
21382189
boxes = fluid.layers.data(name='bboxes', shape=[81, 4],
21392190
dtype='float32', lod_level=1)
21402191
scores = fluid.layers.data(name='scores', shape=[81],

python/paddle/fluid/tests/test_detection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,17 @@ def test_yolov3_loss(self):
482482
self.assertIsNotNone(loss)
483483

484484

485+
class TestBoxClip(unittest.TestCase):
    """Graph-construction smoke test for layers.box_clip."""

    def test_box_clip(self):
        # Only checks that the layer can be appended to a fresh program and
        # yields an output variable; no numeric results are verified here.
        prog = Program()
        with program_guard(prog):
            boxes = layers.data(
                name='input_box',
                shape=[7, 4],
                dtype='float32',
                lod_level=1)
            image_info = layers.data(
                name='im_info', shape=[3], dtype='float32')
            clipped = layers.box_clip(boxes, image_info)
            self.assertIsNotNone(clipped)
494+
495+
485496
class TestMulticlassNMS(unittest.TestCase):
486497
def test_multiclass_nms(self):
487498
program = Program()

0 commit comments

Comments
 (0)