Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/torchcontentarea/csrc/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ enum ImageFormat
gray_long,
};

enum FitCircleStatus
{
success,
no_points,
invalid,
};

struct Image
{
Image(ImageFormat format, const void* data) : format(format), data(data) {}
Expand Down
2 changes: 1 addition & 1 deletion src/torchcontentarea/csrc/cpu_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ namespace cpu
void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score);
void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips);
void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score);
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results);
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results, FitCircleStatus* fit_circle_status);
}
33 changes: 20 additions & 13 deletions src/torchcontentarea/csrc/implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Image get_image_data(torch::Tensor image)
}
}

torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds)
std::tuple<torch::Tensor, FitCircleStatus> estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds)
{
check_image_tensor(image);

Expand All @@ -79,7 +79,9 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);
auto result = batched ?
std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) :
std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid);

if (image.device().is_cpu())
{
Expand All @@ -90,7 +92,7 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe

cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);

cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>(), &std::get<FitCircleStatus>(result));

free(temp_buffer);
}
Expand All @@ -104,15 +106,15 @@ torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, Fe

cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);

cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>());

cudaFree(temp_buffer);
}

return result;
}

torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds)
std::tuple<torch::Tensor, FitCircleStatus> estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds)
{
check_image_tensor(image);

Expand All @@ -127,7 +129,10 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch:
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);
auto result = batched ?
std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) :
std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid);


torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options);
std::vector<torch::jit::IValue> model_input = {strips};
Expand All @@ -145,7 +150,7 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch:

cpu::find_points_from_strip_scores(strip_scores.data_ptr<float>(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);

cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>(), &std::get<FitCircleStatus>(result));

free(temp_buffer);
}
Expand All @@ -163,7 +168,7 @@ torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch:

cuda::find_points_from_strip_scores(strip_scores.data_ptr<float>(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);

cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>());

cudaFree(temp_buffer);
}
Expand Down Expand Up @@ -250,7 +255,7 @@ torch::Tensor get_points_learned(torch::Tensor image, int strip_count, torch::ji
return result;
}

torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds)
std::tuple<torch::Tensor, FitCircleStatus> fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds)
{
check_points(points);

Expand All @@ -262,20 +267,22 @@ torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThr
int point_count = points.size(-1);

torch::TensorOptions options = torch::device(points.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);

auto result = batched ?
std::make_tuple(torch::empty({batch_count, 4}, options), FitCircleStatus::invalid) :
std::make_tuple(torch::empty({4}, options), FitCircleStatus::invalid);

float* temp_buffer = points.data_ptr<float>();
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

if (points.device().is_cpu())
{
cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>(), &std::get<FitCircleStatus>(result));
}
else
{
cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr<float>());
cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, std::get<torch::Tensor>(result).data_ptr<float>());
}

return result;
Expand Down
7 changes: 4 additions & 3 deletions src/torchcontentarea/csrc/implementation.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#pragma once
#include <tuple>
#include <torch/extension.h>
#include "common.hpp"

torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds);
torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds);
std::tuple<torch::Tensor, FitCircleStatus> estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds);
std::tuple<torch::Tensor, FitCircleStatus> estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds);

torch::Tensor get_points_handcrafted(torch::Tensor points, int strip_count, FeatureThresholds feature_thresholds);
torch::Tensor get_points_learned(torch::Tensor points, int strip_count, torch::jit::Module model, int model_patch_size);

torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds);
std::tuple<torch::Tensor, FitCircleStatus> fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds);
10 changes: 9 additions & 1 deletion src/torchcontentarea/csrc/source/fit_circle_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,13 @@ namespace cpu
// =========================================================================
// Main function...

void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results)
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results, FitCircleStatus* fit_circle_status)
{
int* compacted_points = (int*)malloc(3 * point_count * sizeof(int));
int* compacted_points_x = compacted_points + 0 * point_count;
int* compacted_points_y = compacted_points + 1 * point_count;
float* compacted_points_s = (float*)compacted_points + 2 * point_count;
*fit_circle_status = FitCircleStatus::invalid;


for (int batch_index = 0; batch_index < batch_count; ++batch_index)
Expand All @@ -190,10 +191,12 @@ namespace cpu
// Early out...
if (real_point_count < 3)
{
*fit_circle_status = FitCircleStatus::no_points;
return;
}

// Ransac attempts...
bool circle_found = false;
for (int ransac_attempt = 0; ransac_attempt < RANSAC_ATTEMPTS; ++ransac_attempt)
{
int inlier_count = 3;
Expand Down Expand Up @@ -242,8 +245,13 @@ namespace cpu
results[1 + batch_index * 4] = circle_y;
results[2 + batch_index * 4] = circle_r;
results[3 + batch_index * 4] = circle_score;
circle_found = true;
}
}

if (circle_found) {
*fit_circle_status = FitCircleStatus::success;
}
}

free(compacted_points);
Expand Down