Skip to content

Commit a05d25c

Browse files
committed
update code and doc, change input x to LoDTensor
1 parent d458795 commit a05d25c

File tree

4 files changed

+38
-12
lines changed

4 files changed

+38
-12
lines changed

paddle/operators/iou_similarity_op.cc

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,38 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
4444
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
4545
: OpProtoAndCheckerMaker(proto, op_checker) {
4646
AddInput("X",
47-
"(Tensor, default Tensor<float>) "
48-
"Box list X holds N boxes, each box is "
49-
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, "
50-
"4]. [xmin, ymin] is the lower left coordinate of the box, and "
51-
"[xmax, ymax] is the right upper coordinate of the box.");
47+
"(LoDTensor, default LoDTensor<float>) "
48+
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
49+
"each box is represented as [xmin, ymin, xmax, ymax], "
50+
"the shape of X is [N, 4]. [xmin, ymin] is the lower left "
51+
"coordinate of the box, and [xmax, ymax] is the right upper "
52+
"coordinate of the box.This tensor can contain LoD information "
53+
"to represent a batch of inputs. One instance of this batch can "
54+
"contain different numbers of entities.");
5255
AddInput("Y",
5356
"(Tensor, default Tensor<float>) "
5457
"Box list Y holds M boxes, each box is "
5558
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, "
5659
"4]. [xmin, ymin] is the lower left coordinate of the box, and "
5760
"[xmax, ymax] is the right upper coordinate of the box.");
5861

59-
AddOutput(
60-
"Out",
61-
"(Tensor) The output of iou_similarity op, a tensor with shape [N, M] "
62-
"representing pairwise iou scores.");
62+
AddOutput("Out",
63+
"(LoDTensor or Tensor, the lod is same as input X) The output of "
64+
"iou_similarity op, a tensor with shape [N, M] "
65+
"representing pairwise iou scores.");
6366

6467
AddComment(R"DOC(
6568
IOU Similarity Operator.
6669
Computes intersection-over-union (IOU) between two box lists.
70+
Box list 'X' should be a LoDTensor and 'Y' is a common Tensor,
71+
boxes in 'Y' are shared by all input images.
72+
Given two box A and B, the calculation of IOU is as follows:
73+
74+
$$
75+
IOU(A, B) =
76+
\frac{area(A\cap B)}{area(A)+area(B)-area(A\cap B)}
77+
$$
78+
6779
)DOC");
6880
}
6981
};

paddle/operators/iou_similarity_op.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#define EIGEN_USE_GPU
1615
#include "paddle/operators/iou_similarity_op.h"
1716

1817
namespace ops = paddle::operators;

paddle/operators/iou_similarity_op.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ template <typename DeviceContext, typename T>
7171
class IOUSimilarityKernel : public framework::OpKernel<T> {
7272
public:
7373
void Compute(const framework::ExecutionContext& ctx) const override {
74-
const framework::Tensor* in_x = ctx.Input<framework::Tensor>("X");
74+
const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
7575
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
76-
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
76+
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
7777

7878
int x_n = in_x->dims()[0];
7979
int y_n = in_y->dims()[0];
@@ -83,6 +83,8 @@ class IOUSimilarityKernel : public framework::OpKernel<T> {
8383
platform::ForRange<DeviceContext> for_range(
8484
static_cast<const DeviceContext&>(ctx.device_context()), x_n);
8585
for_range(functor);
86+
87+
out->set_lod(in_x->lod());
8688
}
8789
}; // namespace operators
8890

python/paddle/v2/fluid/tests/test_iou_similarity_op.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,18 @@ def setUp(self):
3838
self.outputs = {'Out': self.output}
3939

4040

41+
class TestIOUSimilarityOpWithLoD(TestIOUSimilarityOp):
42+
def test_check_output(self):
43+
self.check_output()
44+
45+
def setUp(self):
46+
super(TestIOUSimilarityOpWithLoD, self).setUp()
47+
self.boxes1_lod = [[0, 1, 2]]
48+
self.output_lod = [[0, 1, 2]]
49+
50+
self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
51+
self.outputs = {'Out': (self.output, self.output_lod)}
52+
53+
4154
if __name__ == '__main__':
4255
unittest.main()

0 commit comments

Comments
 (0)