@@ -560,116 +560,6 @@ std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> rpn_nms_kernel(
560560 return std::make_tuple (bboxes_out, scores_out);
561561}
562562
563- template <typename scalar_t >
564- std::tuple<
565- std::vector<at::Tensor>,
566- std::vector<at::Tensor>,
567- std::vector<at::Tensor>>
568- box_head_nms_kernel (
569- const std::vector<at::Tensor>& batch_bboxes,
570- const std::vector<at::Tensor>& batch_scores,
571- const std::vector<std::tuple<int64_t , int64_t >>& image_shapes,
572- const float score_thresh,
573- const float threshold,
574- const int detections_per_img,
575- const int num_classes) {
576- auto nbatch = batch_scores.size (); // number of batches
577- auto nbatch_x_nclass =
578- nbatch * num_classes; // (number of batches) * (number of labels)
579-
580- std::vector<at::Tensor> bboxes_out (nbatch_x_nclass);
581- std::vector<at::Tensor> scores_out (nbatch_x_nclass);
582- std::vector<at::Tensor> labels_out (nbatch_x_nclass);
583-
584- #ifdef _OPENMP
585- #if (_OPENMP >= 201307)
586- #pragma omp parallel for simd schedule( \
587- static ) if (omp_get_max_threads () > 1 && !omp_in_parallel ())
588- #else
589- #pragma omp parallel for schedule( \
590- static ) if (omp_get_max_threads () > 1 && !omp_in_parallel ())
591- #endif
592- #endif
593- for (int bs = 0 ; bs < nbatch; bs++) {
594- at::Tensor bboxes = batch_bboxes[bs].reshape ({-1 , 4 });
595- at::Tensor scores = batch_scores[bs];
596- auto image_shape = image_shapes[bs];
597- bboxes.slice (1 , 0 , 1 ).clamp_ (0 , std::get<0 >(image_shape) - 1 );
598- bboxes.slice (1 , 1 , 2 ).clamp_ (0 , std::get<1 >(image_shape) - 1 );
599- bboxes.slice (1 , 2 , 3 ).clamp_ (0 , std::get<0 >(image_shape) - 1 );
600- bboxes.slice (1 , 3 , 4 ).clamp_ (0 , std::get<1 >(image_shape) - 1 );
601- bboxes = bboxes.reshape ({-1 , num_classes * 4 });
602- scores = scores.reshape ({-1 , num_classes});
603- at::Tensor indexes = scores > score_thresh;
604-
605- for (int j = 1 ; j < num_classes; j++) {
606- at::Tensor index =
607- at::nonzero (indexes.slice (1 , j, j + 1 ).squeeze (1 )).squeeze (1 );
608- at::Tensor score =
609- scores.slice (1 , j, j + 1 ).squeeze (1 ).index_select (0 , index);
610- at::Tensor bbox =
611- bboxes.slice (1 , j * 4 , (j + 1 ) * 4 ).index_select (0 , index);
612- if (score.size (0 ) == 0 ) {
613- continue ;
614- }
615- auto iter = bs * num_classes + j;
616- if (threshold > 0 ) {
617- at::Tensor keep =
618- nms_cpu_kernel<scalar_t , /* sorted*/ false >(bbox, score, threshold);
619- bboxes_out[iter] = bbox.index_select (0 , keep);
620- scores_out[iter] = score.index_select (0 , keep);
621- labels_out[iter] = at::full ({keep.sizes ()}, j, torch::kInt64 );
622- } else {
623- bboxes_out[iter] = bbox;
624- scores_out[iter] = score;
625- labels_out[iter] = at::full ({score.sizes ()}, j, torch::kInt64 );
626- }
627- }
628- }
629-
630- std::vector<at::Tensor> bboxes_out_ (nbatch);
631- std::vector<at::Tensor> scores_out_ (nbatch);
632- std::vector<at::Tensor> labels_out_ (nbatch);
633-
634- #ifdef _OPENMP
635- #if (_OPENMP >= 201307)
636- #pragma omp parallel for simd schedule( \
637- static ) if (omp_get_max_threads () > 1 && !omp_in_parallel ())
638- #else
639- #pragma omp parallel for schedule( \
640- static ) if (omp_get_max_threads () > 1 && !omp_in_parallel ())
641- #endif
642- #endif
643- for (int bs = 0 ; bs < nbatch; bs++) {
644- std::vector<at::Tensor> valid_bboxes_out =
645- remove_empty (bboxes_out, bs * num_classes, (bs + 1 ) * num_classes);
646- std::vector<at::Tensor> valid_scores_out =
647- remove_empty (scores_out, bs * num_classes, (bs + 1 ) * num_classes);
648- std::vector<at::Tensor> valid_labels_out =
649- remove_empty (labels_out, bs * num_classes, (bs + 1 ) * num_classes);
650- if (valid_bboxes_out.size () > 0 ) {
651- bboxes_out_[bs] = at::cat (valid_bboxes_out, 0 );
652- scores_out_[bs] = at::cat (valid_scores_out, 0 );
653- labels_out_[bs] = at::cat (valid_labels_out, 0 );
654- } else {
655- bboxes_out_[bs] = at::empty ({0 , 4 }, torch::kFloat );
656- scores_out_[bs] = at::empty ({0 }, torch::kFloat );
657- labels_out_[bs] = at::empty ({0 }, torch::kInt64 );
658- }
659- auto number_of_detections = bboxes_out_[bs].size (0 );
660- if (number_of_detections > detections_per_img && detections_per_img > 0 ) {
661- auto out_ = scores_out_[bs].kthvalue (
662- number_of_detections - detections_per_img + 1 );
663- at::Tensor keep =
664- at::nonzero (scores_out_[bs] >= std::get<0 >(out_).item ()).squeeze (1 );
665- bboxes_out_[bs] = bboxes_out_[bs].index_select (0 , keep);
666- scores_out_[bs] = scores_out_[bs].index_select (0 , keep);
667- labels_out_[bs] = labels_out_[bs].index_select (0 , keep);
668- }
669- }
670- return std::make_tuple (bboxes_out_, scores_out_, labels_out_);
671- }
672-
673563at::Tensor nms_cpu_kernel_impl (
674564 const at::Tensor& dets,
675565 const at::Tensor& scores,
@@ -718,37 +608,6 @@ rpn_nms_cpu_kernel_impl(
718608 return result;
719609}
720610
721- std::tuple<
722- std::vector<at::Tensor>,
723- std::vector<at::Tensor>,
724- std::vector<at::Tensor>>
725- box_head_nms_cpu_kernel_impl (
726- const std::vector<at::Tensor>& batch_bboxes,
727- const std::vector<at::Tensor>& batch_scores,
728- const std::vector<std::tuple<int64_t , int64_t >>& image_shapes,
729- const float score_thresh,
730- const float threshold,
731- const int detections_per_img,
732- const int num_classes) {
733- std::tuple<
734- std::vector<at::Tensor>,
735- std::vector<at::Tensor>,
736- std::vector<at::Tensor>>
737- result;
738- AT_DISPATCH_FLOATING_TYPES (
739- batch_bboxes[0 ].scalar_type (), " box_head_nms" , [&] {
740- result = box_head_nms_kernel<scalar_t >(
741- batch_bboxes,
742- batch_scores,
743- image_shapes,
744- score_thresh,
745- threshold,
746- detections_per_img,
747- num_classes);
748- });
749- return result;
750- }
751-
752611} // anonymous namespace
753612
// NOTE(review): presumably binds nms_cpu_kernel_impl as the implementation behind nms_cpu_kernel_stub; confirm against the IPEX_REGISTER_DISPATCH definition.
754613IPEX_REGISTER_DISPATCH (nms_cpu_kernel_stub, &nms_cpu_kernel_impl);
@@ -759,9 +618,5 @@ IPEX_REGISTER_DISPATCH(
759618
// NOTE(review): presumably binds rpn_nms_cpu_kernel_impl as the implementation behind rpn_nms_cpu_kernel_stub; confirm against the IPEX_REGISTER_DISPATCH definition.
760619IPEX_REGISTER_DISPATCH (rpn_nms_cpu_kernel_stub, &rpn_nms_cpu_kernel_impl);
761620
// NOTE(review): presumably binds box_head_nms_cpu_kernel_impl as the implementation behind box_head_nms_cpu_kernel_stub; confirm against the IPEX_REGISTER_DISPATCH definition.
762- IPEX_REGISTER_DISPATCH (
763- box_head_nms_cpu_kernel_stub,
764- &box_head_nms_cpu_kernel_impl);
765-
766621} // namespace cpu
767622} // namespace torch_ipex
0 commit comments