|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
| 13 | + |
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h"
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +using Tensor = framework::Tensor; |
| 21 | +using LoDTensor = framework::LoDTensor; |
| 22 | + |
| 23 | +class MatrixNMSOp : public framework::OperatorWithKernel { |
| 24 | + public: |
| 25 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 26 | + |
| 27 | + void InferShape(framework::InferShapeContext* ctx) const override { |
| 28 | + OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MatrixNMS"); |
| 29 | + OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MatrixNMS"); |
| 30 | + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixNMS"); |
| 31 | + auto box_dims = ctx->GetInputDim("BBoxes"); |
| 32 | + auto score_dims = ctx->GetInputDim("Scores"); |
| 33 | + auto score_size = score_dims.size(); |
| 34 | + |
| 35 | + if (ctx->IsRuntime()) { |
| 36 | + PADDLE_ENFORCE_EQ(score_size == 3, true, |
| 37 | + platform::errors::InvalidArgument( |
| 38 | + "The rank of Input(Scores) must be 3. " |
| 39 | + "But received rank = %d.", |
| 40 | + score_size)); |
| 41 | + PADDLE_ENFORCE_EQ(box_dims.size(), 3, |
| 42 | + platform::errors::InvalidArgument( |
| 43 | + "The rank of Input(BBoxes) must be 3." |
| 44 | + "But received rank = %d.", |
| 45 | + box_dims.size())); |
| 46 | + PADDLE_ENFORCE_EQ(box_dims[2] == 4, true, |
| 47 | + platform::errors::InvalidArgument( |
| 48 | + "The last dimension of Input (BBoxes) must be 4, " |
| 49 | + "represents the layout of coordinate " |
| 50 | + "[xmin, ymin, xmax, ymax].")); |
| 51 | + PADDLE_ENFORCE_EQ( |
| 52 | + box_dims[1], score_dims[2], |
| 53 | + platform::errors::InvalidArgument( |
| 54 | + "The 2nd dimension of Input(BBoxes) must be equal to " |
| 55 | + "last dimension of Input(Scores), which represents the " |
| 56 | + "predicted bboxes." |
| 57 | + "But received box_dims[1](%s) != socre_dims[2](%s)", |
| 58 | + box_dims[1], score_dims[2])); |
| 59 | + } |
| 60 | + ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); |
| 61 | + ctx->SetOutputDim("Index", {box_dims[1], 1}); |
| 62 | + if (!ctx->IsRuntime()) { |
| 63 | + ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1)); |
| 64 | + ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1)); |
| 65 | + } |
| 66 | + } |
| 67 | + |
| 68 | + protected: |
| 69 | + framework::OpKernelType GetExpectedKernelType( |
| 70 | + const framework::ExecutionContext& ctx) const override { |
| 71 | + return framework::OpKernelType( |
| 72 | + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), |
| 73 | + platform::CPUPlace()); |
| 74 | + } |
| 75 | +}; |
| 76 | + |
// Score-decay functor selected at compile time by `gaussian`.
// Call as decay(iou, max_iou, sigma); the result is the factor by which a
// score is multiplied. The call operators are const because the functors are
// stateless, so they can also be used through const objects.
template <typename T, bool gaussian>
struct decay_score;

// Gaussian decay: exp((max_iou^2 - iou^2) * sigma).
template <typename T>
struct decay_score<T, true> {
  T operator()(T iou, T max_iou, T sigma) const {
    return std::exp((max_iou * max_iou - iou * iou) * sigma);
  }
};

// Linear decay: (1 - iou) / (1 - max_iou). `sigma` is accepted only so both
// specializations share one call signature; it is not used here.
// NOTE(review): max_iou == 1 divides by zero and yields inf or NaN.
template <typename T>
struct decay_score<T, false> {
  T operator()(T iou, T max_iou, T /*sigma*/) const {
    return (1. - iou) / (1. - max_iou);
  }
};
| 93 | + |
| 94 | +template <typename T, bool gaussian> |
| 95 | +void NMSMatrix(const Tensor& bbox, const Tensor& scores, |
| 96 | + const T score_threshold, const T post_threshold, |
| 97 | + const float sigma, const int64_t top_k, const bool normalized, |
| 98 | + std::vector<int>* selected_indices, |
| 99 | + std::vector<T>* decayed_scores) { |
| 100 | + int64_t num_boxes = bbox.dims()[0]; |
| 101 | + int64_t box_size = bbox.dims()[1]; |
| 102 | + |
| 103 | + auto score_ptr = scores.data<T>(); |
| 104 | + auto bbox_ptr = bbox.data<T>(); |
| 105 | + |
| 106 | + std::vector<int32_t> perm(num_boxes); |
| 107 | + std::iota(perm.begin(), perm.end(), 0); |
| 108 | + auto end = std::remove_if(perm.begin(), perm.end(), |
| 109 | + [&score_ptr, score_threshold](int32_t idx) { |
| 110 | + return score_ptr[idx] <= score_threshold; |
| 111 | + }); |
| 112 | + |
| 113 | + auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) { |
| 114 | + return score_ptr[lhs] > score_ptr[rhs]; |
| 115 | + }; |
| 116 | + |
| 117 | + int64_t num_pre = std::distance(perm.begin(), end); |
| 118 | + if (num_pre <= 0) { |
| 119 | + return; |
| 120 | + } |
| 121 | + if (top_k > -1 && num_pre > top_k) { |
| 122 | + num_pre = top_k; |
| 123 | + } |
| 124 | + std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn); |
| 125 | + |
| 126 | + std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1); |
| 127 | + std::vector<T> iou_max(num_pre); |
| 128 | + |
| 129 | + iou_max[0] = 0.; |
| 130 | + for (int64_t i = 1; i < num_pre; i++) { |
| 131 | + T max_iou = 0.; |
| 132 | + auto idx_a = perm[i]; |
| 133 | + for (int64_t j = 0; j < i; j++) { |
| 134 | + auto idx_b = perm[j]; |
| 135 | + auto iou = JaccardOverlap<T>(bbox_ptr + idx_a * box_size, |
| 136 | + bbox_ptr + idx_b * box_size, normalized); |
| 137 | + max_iou = std::max(max_iou, iou); |
| 138 | + iou_matrix[i * (i - 1) / 2 + j] = iou; |
| 139 | + } |
| 140 | + iou_max[i] = max_iou; |
| 141 | + } |
| 142 | + |
| 143 | + if (score_ptr[perm[0]] > post_threshold) { |
| 144 | + selected_indices->push_back(perm[0]); |
| 145 | + decayed_scores->push_back(score_ptr[perm[0]]); |
| 146 | + } |
| 147 | + |
| 148 | + decay_score<T, gaussian> decay_fn; |
| 149 | + for (int64_t i = 1; i < num_pre; i++) { |
| 150 | + T min_decay = 1.; |
| 151 | + for (int64_t j = 0; j < i; j++) { |
| 152 | + auto max_iou = iou_max[j]; |
| 153 | + auto iou = iou_matrix[i * (i - 1) / 2 + j]; |
| 154 | + auto decay = decay_fn(iou, max_iou, sigma); |
| 155 | + min_decay = std::min(min_decay, decay); |
| 156 | + } |
| 157 | + auto ds = min_decay * score_ptr[perm[i]]; |
| 158 | + if (ds <= post_threshold) continue; |
| 159 | + selected_indices->push_back(perm[i]); |
| 160 | + decayed_scores->push_back(ds); |
| 161 | + } |
| 162 | +} |
| 163 | + |
| 164 | +template <typename T> |
| 165 | +class MatrixNMSKernel : public framework::OpKernel<T> { |
| 166 | + public: |
| 167 | + size_t MultiClassMatrixNMS(const Tensor& scores, const Tensor& bboxes, |
| 168 | + std::vector<T>* out, std::vector<int>* indices, |
| 169 | + int start, int64_t background_label, |
| 170 | + int64_t nms_top_k, int64_t keep_top_k, |
| 171 | + bool normalized, T score_threshold, |
| 172 | + T post_threshold, bool use_gaussian, |
| 173 | + float gaussian_sigma) const { |
| 174 | + std::vector<int> all_indices; |
| 175 | + std::vector<T> all_scores; |
| 176 | + std::vector<T> all_classes; |
| 177 | + all_indices.reserve(scores.numel()); |
| 178 | + all_scores.reserve(scores.numel()); |
| 179 | + all_classes.reserve(scores.numel()); |
| 180 | + |
| 181 | + size_t num_det = 0; |
| 182 | + auto class_num = scores.dims()[0]; |
| 183 | + Tensor score_slice; |
| 184 | + for (int64_t c = 0; c < class_num; ++c) { |
| 185 | + if (c == background_label) continue; |
| 186 | + score_slice = scores.Slice(c, c + 1); |
| 187 | + if (use_gaussian) { |
| 188 | + NMSMatrix<T, true>(bboxes, score_slice, score_threshold, post_threshold, |
| 189 | + gaussian_sigma, nms_top_k, normalized, &all_indices, |
| 190 | + &all_scores); |
| 191 | + } else { |
| 192 | + NMSMatrix<T, false>(bboxes, score_slice, score_threshold, |
| 193 | + post_threshold, gaussian_sigma, nms_top_k, |
| 194 | + normalized, &all_indices, &all_scores); |
| 195 | + } |
| 196 | + for (size_t i = 0; i < all_indices.size() - num_det; i++) { |
| 197 | + all_classes.push_back(static_cast<T>(c)); |
| 198 | + } |
| 199 | + num_det = all_indices.size(); |
| 200 | + } |
| 201 | + |
| 202 | + if (num_det <= 0) { |
| 203 | + return num_det; |
| 204 | + } |
| 205 | + |
| 206 | + if (keep_top_k > -1) { |
| 207 | + auto k = static_cast<size_t>(keep_top_k); |
| 208 | + if (num_det > k) num_det = k; |
| 209 | + } |
| 210 | + |
| 211 | + std::vector<int32_t> perm(all_indices.size()); |
| 212 | + std::iota(perm.begin(), perm.end(), 0); |
| 213 | + |
| 214 | + std::partial_sort(perm.begin(), perm.begin() + num_det, perm.end(), |
| 215 | + [&all_scores](int lhs, int rhs) { |
| 216 | + return all_scores[lhs] > all_scores[rhs]; |
| 217 | + }); |
| 218 | + |
| 219 | + for (size_t i = 0; i < num_det; i++) { |
| 220 | + auto p = perm[i]; |
| 221 | + auto idx = all_indices[p]; |
| 222 | + auto cls = all_classes[p]; |
| 223 | + auto score = all_scores[p]; |
| 224 | + auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1]; |
| 225 | + (*indices).push_back(start + idx); |
| 226 | + (*out).push_back(cls); |
| 227 | + (*out).push_back(score); |
| 228 | + for (int j = 0; j < bboxes.dims()[1]; j++) { |
| 229 | + (*out).push_back(bbox[j]); |
| 230 | + } |
| 231 | + } |
| 232 | + |
| 233 | + return num_det; |
| 234 | + } |
| 235 | + |
| 236 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 237 | + auto* boxes = ctx.Input<LoDTensor>("BBoxes"); |
| 238 | + auto* scores = ctx.Input<LoDTensor>("Scores"); |
| 239 | + auto* outs = ctx.Output<LoDTensor>("Out"); |
| 240 | + auto* index = ctx.Output<LoDTensor>("Index"); |
| 241 | + |
| 242 | + auto background_label = ctx.Attr<int>("background_label"); |
| 243 | + auto nms_top_k = ctx.Attr<int>("nms_top_k"); |
| 244 | + auto keep_top_k = ctx.Attr<int>("keep_top_k"); |
| 245 | + auto normalized = ctx.Attr<bool>("normalized"); |
| 246 | + auto score_threshold = ctx.Attr<float>("score_threshold"); |
| 247 | + auto post_threshold = ctx.Attr<float>("post_threshold"); |
| 248 | + auto use_gaussian = ctx.Attr<bool>("use_gaussian"); |
| 249 | + auto gaussian_sigma = ctx.Attr<float>("gaussian_sigma"); |
| 250 | + |
| 251 | + auto score_dims = scores->dims(); |
| 252 | + auto batch_size = score_dims[0]; |
| 253 | + auto num_boxes = score_dims[2]; |
| 254 | + auto box_dim = boxes->dims()[2]; |
| 255 | + auto out_dim = box_dim + 2; |
| 256 | + |
| 257 | + Tensor boxes_slice, scores_slice; |
| 258 | + size_t num_out = 0; |
| 259 | + std::vector<size_t> offsets = {0}; |
| 260 | + std::vector<T> detections; |
| 261 | + std::vector<int> indices; |
| 262 | + detections.reserve(out_dim * num_boxes * batch_size); |
| 263 | + indices.reserve(num_boxes * batch_size); |
| 264 | + for (int i = 0; i < batch_size; ++i) { |
| 265 | + scores_slice = scores->Slice(i, i + 1); |
| 266 | + scores_slice.Resize({score_dims[1], score_dims[2]}); |
| 267 | + boxes_slice = boxes->Slice(i, i + 1); |
| 268 | + boxes_slice.Resize({score_dims[2], box_dim}); |
| 269 | + int start = i * score_dims[2]; |
| 270 | + num_out = MultiClassMatrixNMS( |
| 271 | + scores_slice, boxes_slice, &detections, &indices, start, |
| 272 | + background_label, nms_top_k, keep_top_k, normalized, score_threshold, |
| 273 | + post_threshold, use_gaussian, gaussian_sigma); |
| 274 | + offsets.push_back(offsets.back() + num_out); |
| 275 | + } |
| 276 | + |
| 277 | + int64_t num_kept = offsets.back(); |
| 278 | + if (num_kept == 0) { |
| 279 | + outs->mutable_data<T>({0, out_dim}, ctx.GetPlace()); |
| 280 | + index->mutable_data<int>({0, 1}, ctx.GetPlace()); |
| 281 | + } else { |
| 282 | + outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace()); |
| 283 | + index->mutable_data<int>({num_kept, 1}, ctx.GetPlace()); |
| 284 | + std::copy(detections.begin(), detections.end(), outs->data<T>()); |
| 285 | + std::copy(indices.begin(), indices.end(), index->data<int>()); |
| 286 | + } |
| 287 | + |
| 288 | + framework::LoD lod; |
| 289 | + lod.emplace_back(offsets); |
| 290 | + outs->set_lod(lod); |
| 291 | + index->set_lod(lod); |
| 292 | + } |
| 293 | +}; |
| 294 | + |
| 295 | +class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker { |
| 296 | + public: |
| 297 | + void Make() override { |
| 298 | + AddInput("BBoxes", |
| 299 | + "(Tensor) A 3-D Tensor with shape " |
| 300 | + "[N, M, 4] represents the predicted locations of M bounding boxes" |
| 301 | + ", N is the batch size. " |
| 302 | + "Each bounding box has four coordinate values and the layout is " |
| 303 | + "[xmin, ymin, xmax, ymax], when box size equals to 4."); |
| 304 | + AddInput("Scores", |
| 305 | + "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " |
| 306 | + "predicted confidence predictions. N is the batch size, C is the " |
| 307 | + "class number, M is number of bounding boxes. For each category " |
| 308 | + "there are total M scores which corresponding M bounding boxes. " |
| 309 | + " Please note, M is equal to the 2nd dimension of BBoxes. "); |
| 310 | + AddAttr<int>( |
| 311 | + "background_label", |
| 312 | + "(int, default: 0) " |
| 313 | + "The index of background label, the background label will be ignored. " |
| 314 | + "If set to -1, then all categories will be considered.") |
| 315 | + .SetDefault(0); |
| 316 | + AddAttr<float>("score_threshold", |
| 317 | + "(float) " |
| 318 | + "Threshold to filter out bounding boxes with low " |
| 319 | + "confidence score."); |
| 320 | + AddAttr<float>("post_threshold", |
| 321 | + "(float, default 0.) " |
| 322 | + "Threshold to filter out bounding boxes with low " |
| 323 | + "confidence score AFTER decaying.") |
| 324 | + .SetDefault(0.); |
| 325 | + AddAttr<int>("nms_top_k", |
| 326 | + "(int64_t) " |
| 327 | + "Maximum number of detections to be kept according to the " |
| 328 | + "confidences after the filtering detections based on " |
| 329 | + "score_threshold"); |
| 330 | + AddAttr<int>("keep_top_k", |
| 331 | + "(int64_t) " |
| 332 | + "Number of total bboxes to be kept per image after NMS " |
| 333 | + "step. -1 means keeping all bboxes after NMS step."); |
| 334 | + AddAttr<bool>("normalized", |
| 335 | + "(bool, default true) " |
| 336 | + "Whether detections are normalized.") |
| 337 | + .SetDefault(true); |
| 338 | + AddAttr<bool>("use_gaussian", |
| 339 | + "(bool, default false) " |
| 340 | + "Whether to use Gaussian as decreasing function.") |
| 341 | + .SetDefault(false); |
| 342 | + AddAttr<float>("gaussian_sigma", |
| 343 | + "(float) " |
| 344 | + "Sigma for Gaussian decreasing function, only takes effect ", |
| 345 | + "when 'use_gaussian' is enabled.") |
| 346 | + .SetDefault(2.); |
| 347 | + AddOutput("Out", |
| 348 | + "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " |
| 349 | + "detections. Each row has 6 values: " |
| 350 | + "[label, confidence, xmin, ymin, xmax, ymax]. " |
| 351 | + "the offsets in first dimension are called LoD, the number of " |
| 352 | + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " |
| 353 | + "no detected bbox."); |
| 354 | + AddOutput("Index", |
| 355 | + "(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the " |
| 356 | + "index of selected bbox. The index is the absolute index cross " |
| 357 | + "batches."); |
| 358 | + AddComment(R"DOC( |
| 359 | +This operator does multi-class matrix non maximum suppression (NMS) on batched |
| 360 | +boxes and scores. |
| 361 | +In the NMS step, this operator greedily selects a subset of detection bounding |
| 362 | +boxes that have high scores larger than score_threshold, if providing this |
| 363 | +threshold, then selects the largest nms_top_k confidences scores if nms_top_k |
| 364 | +is larger than -1. Then this operator decays boxes score according to the |
| 365 | +Matrix NMS scheme. |
| 366 | +Aftern NMS step, at most keep_top_k number of total bboxes are to be kept |
| 367 | +per image if keep_top_k is larger than -1. |
| 368 | +This operator support multi-class and batched inputs. It applying NMS |
| 369 | +independently for each class. The outputs is a 2-D LoDTenosr, for each |
| 370 | +image, the offsets in first dimension of LoDTensor are called LoD, the number |
| 371 | +of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0, |
| 372 | +means there is no detected bbox for this image. |
| 373 | +
|
| 374 | +For more information on Matrix NMS, please refer to: |
| 375 | +https://arxiv.org/abs/2003.10152 |
| 376 | +)DOC"); |
| 377 | + } |
| 378 | +}; |
| 379 | + |
| 380 | +} // namespace operators |
| 381 | +} // namespace paddle |
| 382 | + |
| 383 | +namespace ops = paddle::operators; |
| 384 | +REGISTER_OPERATOR( |
| 385 | + matrix_nms, ops::MatrixNMSOp, ops::MatrixNMSOpMaker, |
| 386 | + paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, |
| 387 | + paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); |
| 388 | +REGISTER_OP_CPU_KERNEL(matrix_nms, ops::MatrixNMSKernel<float>, |
| 389 | + ops::MatrixNMSKernel<double>); |
0 commit comments