Skip to content

Commit 2ad5a6f

Browse files
committed
add iou similarity operator
1 parent cb6b468 commit 2ad5a6f

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed

paddle/operators/iou_similarity_op.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/iou_similarity_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class IOUSimilarityOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
protected:
25+
void InferShape(framework::InferShapeContext *ctx) const override {
26+
auto x_dims = ctx->GetInputDim("X");
27+
auto y_dims = ctx->GetInputDim("Y");
28+
29+
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The shape of X is [N, 4]");
30+
PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]");
31+
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The shape of Y is [M, 4]");
32+
PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]");
33+
34+
ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]}));
35+
}
36+
};
37+
38+
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
39+
public:
40+
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
41+
: OpProtoAndCheckerMaker(proto, op_checker) {
42+
AddInput(
43+
"X",
44+
"(Tensor, default Tensor<float>) "
45+
"BoxList X holding N boxes, each box is "
46+
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
47+
AddInput(
48+
"Y",
49+
"(Tensor, default Tensor<float>) "
50+
"BoxList Y holding M boxes, each box is "
51+
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
52+
53+
AddOutput(
54+
"Out",
55+
"(Tensor) The output of iou_similarity op, a tensor with shape [N, M] "
56+
"representing pairwise iou scores.");
57+
58+
AddComment(R"DOC(
59+
IOU Similarity Operator.
60+
Computes pairwise intersection-over-union between box collections.
61+
)DOC");
62+
}
63+
};
64+
} // namespace operators
65+
} // namespace paddle
66+
67+
namespace ops = paddle::operators;
68+
REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp,
69+
ops::IOUSimilarityOpMaker);
70+
71+
REGISTER_OP_CPU_KERNEL(
72+
iou_similarity,
73+
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, float>,
74+
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, double>);

paddle/operators/iou_similarity_op.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/platform/for_range.h"
18+
19+
template <typename T>
20+
inline T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, T ymin2,
21+
T xmax2, T ymax2) {
22+
T area1 = (ymax1 - ymin1) * (xmax1 - xmin1);
23+
T area2 = (ymax2 - ymin2) * (xmax2 - xmin2);
24+
T inter_xmax = std::min(xmax1, xmax2);
25+
T inter_ymax = std::min(ymax1, ymax2);
26+
T inter_xmin = std::max(xmin1, xmin2);
27+
T inter_ymin = std::max(ymin1, ymin2);
28+
T inter_height = std::max(inter_ymax - inter_ymin, static_cast<T>(0));
29+
T inter_width = std::max(inter_xmax - inter_xmin, static_cast<T>(0));
30+
T inter_area = inter_width * inter_height;
31+
T union_area = area1 + area2 - inter_area;
32+
T sim_score = inter_area / union_area;
33+
return sim_score;
34+
}
35+
36+
template <typename T>
37+
struct IOUSimilarityFunctor {
38+
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols)
39+
: x_(x), y_(y), z_(z), cols_(static_cast<size_t>(cols)) {}
40+
41+
inline HOSTDEVICE void operator()(size_t row_id) const {
42+
T x_min1 = x_[row_id * 4];
43+
T y_min1 = x_[row_id * 4 + 1];
44+
T x_max1 = x_[row_id * 4 + 2];
45+
T y_max1 = x_[row_id * 4 + 3];
46+
for (int i = 0; i < cols_; ++i) {
47+
T x_min2 = y_[i * 4];
48+
T y_min2 = y_[i * 4 + 1];
49+
T x_max2 = y_[i * 4 + 2];
50+
T y_max2 = y_[i * 4 + 3];
51+
52+
T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2,
53+
x_max2, y_max2);
54+
55+
z_[row_id * cols_ + i] = sim;
56+
}
57+
}
58+
const T* x_;
59+
const T* y_;
60+
T* z_;
61+
const size_t cols_;
62+
};
63+
64+
namespace paddle {
65+
namespace operators {
66+
67+
template <typename DeviceContext, typename T>
68+
class IOUSimilarityKernel : public framework::OpKernel<T> {
69+
public:
70+
void Compute(const framework::ExecutionContext& ctx) const override {
71+
const framework::Tensor* in_x = ctx.Input<framework::Tensor>("X");
72+
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
73+
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
74+
75+
int x_n = in_x->dims()[0];
76+
int y_n = in_y->dims()[0];
77+
IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(),
78+
out->mutable_data<T>(ctx.GetPlace()), y_n);
79+
80+
platform::ForRange<DeviceContext> for_range(
81+
static_cast<const DeviceContext&>(ctx.device_context()), x_n);
82+
for_range(functor);
83+
}
84+
}; // namespace operators
85+
86+
} // namespace operators
87+
} // namespace paddle
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
import numpy as np
3+
import sys
4+
import math
5+
from op_test import OpTest
6+
7+
8+
class TestIOUSimilarityOp(OpTest):
9+
def set_data(self):
10+
self.init_test_data()
11+
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
12+
13+
self.outputs = {'Out': self.output}
14+
15+
def test_check_output(self):
16+
self.check_output()
17+
18+
def test_check_grad(self):
19+
return
20+
21+
def setUp(self):
22+
self.op_type = "iou_similarity"
23+
self.set_data()
24+
25+
def init_test_data(self):
26+
self.boxes1 = np.array(
27+
[[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32')
28+
self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
29+
[0.0, 0.0, 20.0, 20.0]]).astype('float32')
30+
self.output = np.array(
31+
[[2.0 / 16.0, 0, 6.0 / 400.0],
32+
[1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32')
33+
34+
35+
if __name__ == '__main__':
36+
unittest.main()

0 commit comments

Comments
 (0)