|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/operators/iou_similarity_op.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +class IOUSimilarityOp : public framework::OperatorWithKernel { |
| 21 | + public: |
| 22 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 23 | + |
| 24 | + protected: |
| 25 | + void InferShape(framework::InferShapeContext *ctx) const override { |
| 26 | + PADDLE_ENFORCE(ctx->HasInput("X"), |
| 27 | + "Input(X) of IOUSimilarityOp should not be null."); |
| 28 | + PADDLE_ENFORCE(ctx->HasInput("Y"), |
| 29 | + "Input(Y) of IOUSimilarityOp should not be null."); |
| 30 | + auto x_dims = ctx->GetInputDim("X"); |
| 31 | + auto y_dims = ctx->GetInputDim("Y"); |
| 32 | + |
| 33 | + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The rank of Input(X) must be 2."); |
| 34 | + PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]"); |
| 35 | + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The rank of Input(Y) must be 2."); |
| 36 | + PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]"); |
| 37 | + |
| 38 | + ctx->ShareLoD("X", /*->*/ "Out"); |
| 39 | + ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]})); |
| 40 | + } |
| 41 | +}; |
| 42 | + |
| 43 | +class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { |
| 44 | + public: |
| 45 | + IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) |
| 46 | + : OpProtoAndCheckerMaker(proto, op_checker) { |
| 47 | + AddInput("X", |
| 48 | + "(LoDTensor, default LoDTensor<float>) " |
| 49 | + "Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, " |
| 50 | + "each box is represented as [xmin, ymin, xmax, ymax], " |
| 51 | + "the shape of X is [N, 4]. [xmin, ymin] is the left top " |
| 52 | + "coordinate of the box if the input is image feature map, they " |
| 53 | + "are close to the origin of the coordinate system. " |
| 54 | + "[xmax, ymax] is the right bottom coordinate of the box. " |
| 55 | + "This tensor can contain LoD information to represent a batch " |
| 56 | + "of inputs. One instance of this batch can contain different " |
| 57 | + "numbers of entities."); |
| 58 | + AddInput("Y", |
| 59 | + "(Tensor, default Tensor<float>) " |
| 60 | + "Box list Y holds M boxes, each box is represented as " |
| 61 | + "[xmin, ymin, xmax, ymax], the shape of X is [N, 4]. " |
| 62 | + "[xmin, ymin] is the left top coordinate of the box if the " |
| 63 | + "input is image feature map, and [xmax, ymax] is the right " |
| 64 | + "bottom coordinate of the box."); |
| 65 | + |
| 66 | + AddOutput("Out", |
| 67 | + "(LoDTensor, the lod is same as input X) The output of " |
| 68 | + "iou_similarity op, a tensor with shape [N, M] " |
| 69 | + "representing pairwise iou scores."); |
| 70 | + |
| 71 | + AddComment(R"DOC( |
| 72 | +IOU Similarity Operator. |
| 73 | +Computes intersection-over-union (IOU) between two box lists. |
| 74 | + Box list 'X' should be a LoDTensor and 'Y' is a common Tensor, |
| 75 | + boxes in 'Y' are shared by all instance of the batched inputs of X. |
| 76 | + Given two boxes A and B, the calculation of IOU is as follows: |
| 77 | +
|
| 78 | +$$ |
| 79 | +IOU(A, B) = |
| 80 | +\frac{area(A\cap B)}{area(A)+area(B)-area(A\cap B)} |
| 81 | +$$ |
| 82 | +
|
| 83 | +)DOC"); |
| 84 | + } |
| 85 | +}; |
| 86 | +} // namespace operators |
| 87 | +} // namespace paddle |
| 88 | + |
| 89 | +namespace ops = paddle::operators; |
| 90 | +REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp, |
| 91 | + ops::IOUSimilarityOpMaker); |
| 92 | + |
| 93 | +REGISTER_OP_CPU_KERNEL( |
| 94 | + iou_similarity, |
| 95 | + ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, float>, |
| 96 | + ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, double>); |
0 commit comments