diff --git a/src/torchcontentarea/csrc/common.hpp b/src/torchcontentarea/csrc/common.hpp index 32123ca..6eab40b 100644 --- a/src/torchcontentarea/csrc/common.hpp +++ b/src/torchcontentarea/csrc/common.hpp @@ -41,6 +41,13 @@ enum ImageFormat gray_long, }; +enum FitCircleStatus +{ + success, + no_points, + invalid, +}; + struct Image { Image(ImageFormat format, const void* data) : format(format), data(data) {} diff --git a/src/torchcontentarea/csrc/cpu_functions.hpp b/src/torchcontentarea/csrc/cpu_functions.hpp index edc73e5..39e230c 100644 --- a/src/torchcontentarea/csrc/cpu_functions.hpp +++ b/src/torchcontentarea/csrc/cpu_functions.hpp @@ -6,5 +6,5 @@ namespace cpu void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score); void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips); void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score); - void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results); + void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results, FitCircleStatus* fit_circle_status); } diff --git a/src/torchcontentarea/csrc/implementation.cpp b/src/torchcontentarea/csrc/implementation.cpp index e34becd..048458b 100644 --- a/src/torchcontentarea/csrc/implementation.cpp +++ b/src/torchcontentarea/csrc/implementation.cpp @@ -64,7 +64,7 @@ Image get_image_data(torch::Tensor image) } } -torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds) +std::tuple estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds) { check_image_tensor(image); @@ -79,7 +79,9 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe int point_count = 2 * strip_count; torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); - torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options); + auto result = batched ? + std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) : + std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid); if (image.device().is_cpu()) { @@ -90,7 +92,7 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); - cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr(), &std::get(result)); free(temp_buffer); } @@ -104,7 +106,7 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); - cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr()); cudaFree(temp_buffer); } @@ -112,7 +114,7 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe return result; } -torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds) +std::tuple estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds) { check_image_tensor(image); @@ -127,7 +129,10 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch: int point_count = 2 * strip_count; torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); - torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options); + auto result = batched ? + std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) : + std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid); + torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options); std::vector model_input = {strips}; @@ -145,7 +150,7 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch: cpu::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); - cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr(), &std::get(result)); free(temp_buffer); } @@ -163,7 +168,7 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch: cuda::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); - cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr()); cudaFree(temp_buffer); } @@ -250,7 +255,7 @@ torch::Tensor get_points_learned(torch::Tensor image, int strip_count, torch::ji return result; } -torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds) +std::tuple fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds) { check_points(points); @@ -262,8 +267,10 @@ torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThr int point_count = points.size(-1); torch::TensorOptions options = torch::device(points.device()).dtype(torch::kFloat32); - torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options); - + auto result = batched ? + std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) : + std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid); + float* temp_buffer = points.data_ptr(); float* points_x = temp_buffer + 0 * point_count; float* points_y = temp_buffer + 1 * point_count; @@ -271,11 +278,11 @@ torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThr if (points.device().is_cpu()) { - cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr(), &std::get(result)); } else { - cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); + cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get(result).data_ptr()); } return result; diff --git a/src/torchcontentarea/csrc/implementation.hpp b/src/torchcontentarea/csrc/implementation.hpp index 06c09f2..ff7b70f 100644 --- a/src/torchcontentarea/csrc/implementation.hpp +++ b/src/torchcontentarea/csrc/implementation.hpp @@ -1,11 +1,12 @@ #pragma once +#include #include #include "common.hpp" -torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds); -torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds); +std::tuple estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds); +std::tuple estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds); torch::Tensor get_points_handcrafted(torch::Tensor points, int strip_count, FeatureThresholds feature_thresholds); torch::Tensor get_points_learned(torch::Tensor points, int strip_count, torch::jit::Module model, int model_patch_size); -torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds); +std::tuple fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds); diff --git a/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp b/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp index 5dd1c78..b4a0e42 100644 --- a/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp +++ b/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp @@ -159,12 +159,13 @@ namespace cpu // ========================================================================= // Main function... - void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results) + void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results, FitCircleStatus* fit_circle_status) { int* compacted_points = (int*)malloc(3 * point_count * sizeof(int)); int* compacted_points_x = compacted_points + 0 * point_count; int* compacted_points_y = compacted_points + 1 * point_count; float* compacted_points_s = (float*)compacted_points + 2 * point_count; + *fit_circle_status = FitCircleStatus::invalid; for (int batch_index = 0; batch_index < batch_count; ++batch_index) @@ -190,10 +191,12 @@ namespace cpu // Early out... if (real_point_count < 3) { + *fit_circle_status = FitCircleStatus::no_points; return; } // Ransac attempts... + bool circle_found = false; for (int ransac_attempt = 0; ransac_attempt < RANSAC_ATTEMPTS; ++ransac_attempt) { int inlier_count = 3; @@ -242,8 +245,13 @@ namespace cpu results[1 + batch_index * 4] = circle_y; results[2 + batch_index * 4] = circle_r; results[3 + batch_index * 4] = circle_score; + circle_found = true; } } + + if (circle_found) { + *fit_circle_status = FitCircleStatus::success; + } } free(compacted_points);