From 4b41566ea333dda61f18f731024761e8d7b4c650 Mon Sep 17 00:00:00 2001 From: guoshengjian Date: Fri, 29 Aug 2025 12:28:47 +0000 Subject: [PATCH 1/2] feat add return word box --- deploy/cpp_infer/cli.cc | 8 +- .../src/api/models/text_recognition.cc | 1 + .../src/api/models/text_recognition.h | 1 + deploy/cpp_infer/src/api/pipelines/ocr.cc | 1 + deploy/cpp_infer/src/api/pipelines/ocr.h | 1 + deploy/cpp_infer/src/common/processors.cc | 137 +++++++++++++++++- deploy/cpp_infer/src/common/processors.h | 16 +- .../src/modules/text_recognition/predictor.cc | 46 +++++- .../src/modules/text_recognition/predictor.h | 3 + .../modules/text_recognition/processors.cc | 127 ++++++++++++++-- .../src/modules/text_recognition/processors.h | 37 +++-- .../cpp_infer/src/pipelines/ocr/pipeline.cc | 49 +++++++ deploy/cpp_infer/src/pipelines/ocr/pipeline.h | 8 +- deploy/cpp_infer/src/pipelines/ocr/result.cc | 83 ++++++++--- deploy/cpp_infer/src/utils/args.cc | 1 + deploy/cpp_infer/src/utils/args.h | 1 + deploy/cpp_infer/src/utils/yaml_config.cc | 7 + 17 files changed, 465 insertions(+), 62 deletions(-) diff --git a/deploy/cpp_infer/cli.cc b/deploy/cpp_infer/cli.cc index a5b526efb1..305e80f1e9 100644 --- a/deploy/cpp_infer/cli.cc +++ b/deploy/cpp_infer/cli.cc @@ -67,7 +67,7 @@ std::tuple -GetPipelineMoudleParams() { +GetPipelineModuleParams() { PaddleOCRParams ocr_params; DocPreprocessorParams doc_pre_params; DocImgOrientationClassificationParams doc_orient_params; @@ -181,6 +181,10 @@ GetPipelineMoudleParams() { if (!FLAGS_text_rec_score_thresh.empty()) { ocr_params.text_rec_score_thresh = std::stof(FLAGS_text_rec_score_thresh); } + if (!FLAGS_return_word_box.empty()) { + ocr_params.return_word_box = Utility::StringToBool(FLAGS_return_word_box); + rec_params.return_word_box = Utility::StringToBool(FLAGS_return_word_box); + } if (!FLAGS_text_rec_input_shape.empty()) { ocr_params.text_rec_input_shape = YamlConfig::SmartParseVector(FLAGS_text_rec_input_shape).vec_int; @@ -280,7 +284,7 @@ int main(int argc, char *argv[]) { " [--param1] [--param2] [...]"); exit(-1); } - auto params = GetPipelineMoudleParams(); + auto params = GetPipelineModuleParams(); using PredFunc = std::function>( const std::string &)>; std::unordered_map pred_map = { diff --git a/deploy/cpp_infer/src/api/models/text_recognition.cc b/deploy/cpp_infer/src/api/models/text_recognition.cc index 6ca09cf47e..90dcea95e5 100644 --- a/deploy/cpp_infer/src/api/models/text_recognition.cc +++ b/deploy/cpp_infer/src/api/models/text_recognition.cc @@ -51,6 +51,7 @@ TextRecPredictorParams TextRecognition::ToTextRecognitionModelParams( COPY_PARAMS(model_dir) COPY_PARAMS(batch_size) COPY_PARAMS(input_shape) + COPY_PARAMS(return_word_box) COPY_PARAMS(vis_font_dir) COPY_PARAMS(device) COPY_PARAMS(enable_mkldnn) diff --git a/deploy/cpp_infer/src/api/models/text_recognition.h b/deploy/cpp_infer/src/api/models/text_recognition.h index 360e036cce..1f872b3398 100644 --- a/deploy/cpp_infer/src/api/models/text_recognition.h +++ b/deploy/cpp_infer/src/api/models/text_recognition.h @@ -29,6 +29,7 @@ struct TextRecognitionParams { int cpu_threads = 8; int batch_size = 1; absl::optional> input_shape = absl::nullopt; + absl::optional return_word_box = absl::nullopt; }; class TextRecognition { diff --git a/deploy/cpp_infer/src/api/pipelines/ocr.cc b/deploy/cpp_infer/src/api/pipelines/ocr.cc index b1f17371c2..0efe7a1c72 100644 --- a/deploy/cpp_infer/src/api/pipelines/ocr.cc +++ b/deploy/cpp_infer/src/api/pipelines/ocr.cc @@ -86,6 +86,7 @@ OCRPipelineParams PaddleOCR::ToOCRPipelineParams(const PaddleOCRParams &from) { COPY_PARAMS(text_det_input_shape) COPY_PARAMS(text_rec_score_thresh) COPY_PARAMS(text_rec_input_shape) + COPY_PARAMS(return_word_box) COPY_PARAMS(lang) COPY_PARAMS(ocr_version) COPY_PARAMS(vis_font_dir) diff --git a/deploy/cpp_infer/src/api/pipelines/ocr.h b/deploy/cpp_infer/src/api/pipelines/ocr.h index c9f87e6003..89327afd50 100644 --- a/deploy/cpp_infer/src/api/pipelines/ocr.h +++ b/deploy/cpp_infer/src/api/pipelines/ocr.h @@ -42,6 +42,7 @@ struct PaddleOCRParams { absl::optional> text_det_input_shape = absl::nullopt; absl::optional text_rec_score_thresh = absl::nullopt; absl::optional> text_rec_input_shape = absl::nullopt; + absl::optional return_word_box = absl::nullopt; absl::optional lang = absl::nullopt; absl::optional ocr_version = absl::nullopt; absl::optional vis_font_dir = absl::nullopt; diff --git a/deploy/cpp_infer/src/common/processors.cc b/deploy/cpp_infer/src/common/processors.cc index 25b82d8f59..cdf2d10cca 100644 --- a/deploy/cpp_infer/src/common/processors.cc +++ b/deploy/cpp_infer/src/common/processors.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -636,9 +637,9 @@ std::vector> ComponentsProcessor::SortPolyBoxes( return dt_polys_rank; } -std::vector> ComponentsProcessor::ConvertPointsToBoxes( +std::vector> ComponentsProcessor::ConvertPointsToBoxes( const std::vector> &dt_polys) { - std::vector> dt_boxes; + std::vector> dt_boxes; for (const auto &poly : dt_polys) { if (poly.empty()) { continue; @@ -658,7 +659,8 @@ std::vector> ComponentsProcessor::ConvertPointsToBoxes( if (pt.y > bottom) bottom = pt.y; } - dt_boxes.push_back({left, top, right, bottom}); + dt_boxes.push_back({static_cast(left), static_cast(top), + static_cast(right), static_cast(bottom)}); } return dt_boxes; } @@ -773,17 +775,14 @@ CropByPolys::GetPolyRectCrop(const cv::Mat &img, if (poly.size() < 4) return absl::InvalidArgumentError( "Less than 4 points for GetPolyRectCrop."); - // 对Poly和最小外接矩形做IoU判断 std::vector minrect = GetMinAreaRectPoints(poly); if (minrect.size() != 4) return absl::InternalError("Failed to get minarea rect."); double iou = IoU(poly, minrect); - // 若IoU>0.7则返回直接crop,否则可做更复杂处理,如透视矫正,可进一步实现自定义变形矫正 auto crop_result = GetRotateCropImage(img, minrect); if (!crop_result.ok()) return crop_result.status(); - // 测试下如果IoU很高就用直接的最小外接矩形crop,否则复杂矫正(本实现只用直接crop) - // 若需更强几何修复,可集成TPS、ThinPlateSpline或AutoRectifier + return *crop_result; } @@ -824,3 +823,127 @@ double CropByPolys::IoU(const std::vector &poly1, return 0.0; return area_inter / area_union; } + +std::vector +ComponentsProcessor::SortBoxes(const std::vector &boxes, float y_thresh) { + struct BoxWithCenter { + Box box; + Point center; + }; + std::vector items; + for (const Box &box : boxes) { + double x = 0, y = 0; + for (const auto &p : box) { + x += p.x; + y += p.y; + } + x /= box.size(); + y /= box.size(); + items.push_back({box, Point(x, y)}); + } + std::sort(items.begin(), items.end(), + [](const BoxWithCenter &a, const BoxWithCenter &b) { + return a.center.y < b.center.y; + }); + + std::vector> lines; + std::vector current_line; + double last_y = NAN; + for (const auto &item : items) { + if (std::isnan(last_y) || std::fabs(item.center.y - last_y) < y_thresh) { + current_line.push_back(item); + } else { + lines.push_back(current_line); + current_line.clear(); + current_line.push_back(item); + } + last_y = item.center.y; + } + if (!current_line.empty()) + lines.push_back(current_line); + + std::vector final_boxes; + for (auto &line : lines) { + std::sort(line.begin(), line.end(), + [](const BoxWithCenter &a, const BoxWithCenter &b) { + return a.center.x < b.center.x; + }); + for (const auto &item : line) { + final_boxes.push_back(item.box); + } + } + return final_boxes; +} + +std::pair, std::vector> +ComponentsProcessor::CalOCRWordBox( + const std::string &rec_str, const Box &box, int col_num, + const std::vector &word_list, + const std::vector> &word_col_list, + const std::vector &state_list) { + std::wstring_convert> converter; + std::wstring text = converter.from_bytes(rec_str); + double bbox_x_start = box[0].x; + double bbox_x_end = box[1].x; + double bbox_y_start = box[0].y; + double bbox_y_end = box[2].y; + double cell_width = (bbox_x_end - bbox_x_start) / col_num; + + std::vector word_box_list; + std::vector word_box_content_list; + std::vector cn_width_list; + std::vector cn_col_list; + std::wstring word_box_content_cn; + + for (size_t idx = 0; idx < word_list.size(); ++idx) { + const std::wstring &word = word_list[idx]; + const std::vector &word_col = word_col_list[idx]; + const std::string &state = state_list[idx]; + if (state == "cn") { + if (word_col.size() != 1) { + double char_seq_length = + (word_col.back() - word_col.front() + 1) * cell_width; + double char_width = char_seq_length / (word_col.size() - 1); + cn_width_list.push_back(char_width); + } + for (int col : word_col) + cn_col_list.push_back(col); + word_box_content_cn += word; + } else { + double cell_x_start = bbox_x_start + word_col.front() * cell_width; + double cell_x_end = bbox_x_start + (word_col.back() + 1) * cell_width; + Box cell = { + Point(cell_x_start, bbox_y_start), Point(cell_x_end, bbox_y_start), + Point(cell_x_end, bbox_y_end), Point(cell_x_start, bbox_y_end)}; + word_box_list.push_back(cell); + word_box_content_list.push_back(word); + } + } + + if (!cn_col_list.empty()) { + double avg_char_width; + if (!cn_width_list.empty()) { + avg_char_width = + std::accumulate(cn_width_list.begin(), cn_width_list.end(), 0.0) / + cn_width_list.size(); + } else { + avg_char_width = (bbox_x_end - bbox_x_start) / rec_str.size(); + } + for (int center_idx : cn_col_list) { + double center_x = (center_idx + 0.5) * cell_width; + double cell_x_start = + std::max(center_x - avg_char_width / 2, 0.0) + bbox_x_start; + double cell_x_end = + std::min(center_x + avg_char_width / 2, bbox_x_end - bbox_x_start) + + bbox_x_start; + Box cell = { + Point(cell_x_start, bbox_y_start), Point(cell_x_end, bbox_y_start), + Point(cell_x_end, bbox_y_end), Point(cell_x_start, bbox_y_end)}; + word_box_list.push_back(cell); + } + word_box_content_list.push_back(word_box_content_cn); + } + + std::vector sorted_word_box_list = SortBoxes(word_box_list, 12.0); + return {word_list, sorted_word_box_list}; +} diff --git a/deploy/cpp_infer/src/common/processors.h b/deploy/cpp_infer/src/common/processors.h index c63f7ae06f..2e816eb12f 100644 --- a/deploy/cpp_infer/src/common/processors.h +++ b/deploy/cpp_infer/src/common/processors.h @@ -135,13 +135,27 @@ class ToBatch : public BaseProcessor { class ComponentsProcessor { public: + // using Point = std::array; + // using Box = std::array ; + using Point = cv::Point2f; + using Box = std::vector; + static absl::StatusOr RotateImage(const cv::Mat &image, int angle); static std::vector> SortQuadBoxes(const std::vector> &dt_polys); static std::vector> SortPolyBoxes(const std::vector> &dt_polys); - static std::vector> + static std::vector> ConvertPointsToBoxes(const std::vector> &dt_polys); + + static std::vector SortBoxes(const std::vector &boxes, + float y_thresh = 10.0); + + static std::pair, std::vector> + CalOCRWordBox(const std::string &rec_str, const Box &box, int col_num, + const std::vector &word_list, + const std::vector> &word_col_list, + const std::vector &state_list); }; class CropByPolys { diff --git a/deploy/cpp_infer/src/modules/text_recognition/predictor.cc b/deploy/cpp_infer/src/modules/text_recognition/predictor.cc index d4df5a908b..dd3ece16a2 100644 --- a/deploy/cpp_infer/src/modules/text_recognition/predictor.cc +++ b/deploy/cpp_infer/src/modules/text_recognition/predictor.cc @@ -36,7 +36,11 @@ TextRecPredictor::TextRecPredictor(const TextRecPredictorParams ¶ms) absl::Status TextRecPredictor::Build() { const auto &pre_params = config_.PreProcessOpInfo(); Register("Read", "BGR"); //****** - Register("ReisizeNorm", params_.input_shape); + rec_image_shape_ = + YamlConfig::SmartParseVector(pre_params.at("RecResizeImg.image_shape")) + .vec_int; + Register("ReisizeNorm", params_.input_shape, + rec_image_shape_); Register("ToBatch"); infer_ptr_ = CreateStaticInfer(); const auto &post_params = config_.PostProcessOpInfo(); @@ -60,6 +64,19 @@ TextRecPredictor::Process(std::vector &batch_data) { exit(-1); } + std::vector width_list; + for (const auto &img : batch_read.value()) { + double ratio = static_cast(img.cols) / static_cast(img.rows); + width_list.push_back(ratio); + } + + std::vector indices(width_list.size()); + for (int i = 0; i < indices.size(); ++i) + indices[i] = i; + + std::sort(indices.begin(), indices.end(), + [&](int a, int b) { return width_list[a] < width_list[b]; }); + auto batch_resize_norm = pre_op_.at("ReisizeNorm")->Apply(batch_read.value()); if (!batch_resize_norm.ok()) { INFOE(batch_resize_norm.status().ToString().c_str()); @@ -77,8 +94,26 @@ TextRecPredictor::Process(std::vector &batch_data) { exit(-1); } - auto ctc_result = - post_op_.at("CTCLabelDecode")->Apply(batch_infer.value()[0]); + int batch_num = batch_sampler_ptr_->BatchSize(); + int img_num = batch_data.size(); + + int imgC = rec_image_shape_[0]; + int imgH = rec_image_shape_[1]; + int imgW = rec_image_shape_[2]; + float max_wh_ratio = static_cast(imgW) / static_cast(imgH); + int end_img_no = std::min(img_num, batch_num); + std::vector wh_ratio_list = {}; + for (int ino = 0; ino < end_img_no; ino++) { + int h = batch_read.value()[indices[ino]].size[0]; + int w = batch_read.value()[indices[ino]].size[1]; + float wh_ratio = static_cast(w) / static_cast(h); + max_wh_ratio = std::max(max_wh_ratio, wh_ratio); + wh_ratio_list.push_back(wh_ratio); + } + auto ctc_result = post_op_.at("CTCLabelDecode") + ->Apply(batch_infer.value()[0], + params_.return_word_box.value_or(false), + wh_ratio_list, max_wh_ratio); if (!ctc_result.ok()) { INFOE(ctc_result.status().ToString().c_str()); @@ -94,8 +129,9 @@ TextRecPredictor::Process(std::vector &batch_data) { predictor_result.input_path = input_path_[input_index_]; } predictor_result.input_image = origin_image[i]; - predictor_result.rec_text = ctc_result.value()[i].first; - predictor_result.rec_score = ctc_result.value()[i].second; + predictor_result.rec_text = ctc_result.value()[i].sentence.first; + predictor_result.rec_score = ctc_result.value()[i].sentence.second; + predictor_result.ctc_result = ctc_result.value()[i]; predictor_result.vis_font = params_.vis_font_dir.value_or(""); predictor_result_vec_.push_back(predictor_result); base_cv_result_ptr_vec.push_back( diff --git a/deploy/cpp_infer/src/modules/text_recognition/predictor.h b/deploy/cpp_infer/src/modules/text_recognition/predictor.h index 41b003ea40..aa0542c695 100644 --- a/deploy/cpp_infer/src/modules/text_recognition/predictor.h +++ b/deploy/cpp_infer/src/modules/text_recognition/predictor.h @@ -26,6 +26,7 @@ struct TextRecPredictorResult { std::string rec_text = ""; float rec_score = 0.0; std::string vis_font = ""; + CTCLabelDecodeResult ctc_result; }; struct TextRecPredictorParams { @@ -41,6 +42,7 @@ struct TextRecPredictorParams { int cpu_threads = 8; int batch_size = 1; absl::optional> input_shape = absl::nullopt; + absl::optional return_word_box = absl::nullopt; }; class TextRecPredictor : public BasePredictor { @@ -66,4 +68,5 @@ class TextRecPredictor : public BasePredictor { std::unique_ptr infer_ptr_; TextRecPredictorParams params_; int input_index_ = 0; + std::vector rec_image_shape_ = {}; }; diff --git a/deploy/cpp_infer/src/modules/text_recognition/processors.cc b/deploy/cpp_infer/src/modules/text_recognition/processors.cc index 1f374b5193..ef3e93f09f 100644 --- a/deploy/cpp_infer/src/modules/text_recognition/processors.cc +++ b/deploy/cpp_infer/src/modules/text_recognition/processors.cc @@ -14,7 +14,10 @@ #include "processors.h" +#include +#include #include +#include #include #include #include @@ -22,7 +25,7 @@ #include "src/utils/utility.h" absl::StatusOr> -OCRReisizeNormImg::Apply(std::vector &input, const void *param) const { +OCRResizeNormImg::Apply(std::vector &input, const void *param) const { std::vector output = {}; output.reserve(input.size()); if (input_shape_.empty()) { @@ -45,7 +48,7 @@ OCRReisizeNormImg::Apply(std::vector &input, const void *param) const { return output; } -absl::StatusOr OCRReisizeNormImg::Resize(cv::Mat &image) const { +absl::StatusOr OCRResizeNormImg::Resize(cv::Mat &image) const { float rec_wh_ratio = (float)rec_image_shape_[2] / (float)rec_image_shape_[1]; float image_wh_ratio = (float)image.size[1] / (float)image.size[0]; float max_wh_ratio = std::max(rec_wh_ratio, image_wh_ratio); @@ -56,7 +59,7 @@ absl::StatusOr OCRReisizeNormImg::Resize(cv::Mat &image) const { return image_result.value(); } -absl::StatusOr OCRReisizeNormImg::StaticResize(cv::Mat &image) const { +absl::StatusOr OCRResizeNormImg::StaticResize(cv::Mat &image) const { cv::Mat resize_image; int img_c = input_shape_[0]; int img_h = input_shape_[1]; @@ -79,7 +82,7 @@ absl::StatusOr OCRReisizeNormImg::StaticResize(cv::Mat &image) const { } absl::StatusOr -OCRReisizeNormImg::ResizeNormImg(cv::Mat &image, float max_wh_ratio) const { +OCRResizeNormImg::ResizeNormImg(cv::Mat &image, float max_wh_ratio) const { assert(rec_image_shape_[0] == image.channels()); int rec_c = rec_image_shape_[0]; int rec_h = rec_image_shape_[1]; @@ -146,26 +149,35 @@ CTCLabelDecode::CTCLabelDecode(const std::vector &character_list, } } -absl::StatusOr>> -CTCLabelDecode::Apply(const cv::Mat &preds) const { +absl::StatusOr> +CTCLabelDecode::Apply(const cv::Mat &preds, const bool return_word_box, + std::vector wh_ratio_list, + float max_wh_ratio) const { auto preds_batch = Utility::SplitBatch(preds); - std::vector> ctc_result = {}; + std::vector ctc_result = {}; ctc_result.reserve(preds_batch.value().size()); if (!preds_batch.ok()) { return preds_batch.status(); } for (const auto &pred : preds_batch.value()) { - auto result = Process(pred); + auto result = Process(pred, return_word_box); if (!result.ok()) { return result.status(); } ctc_result.push_back(result.value()); } + if (return_word_box) { + for (int i = 0; i < ctc_result.size(); i++) { + float wh_ratio = wh_ratio_list[i]; + ctc_result[i].sentence_len = + ctc_result[i].sentence_len * (wh_ratio / max_wh_ratio); + } + } return ctc_result; } -absl::StatusOr> -CTCLabelDecode::Process(const cv::Mat &pred_data) const { +absl::StatusOr +CTCLabelDecode::Process(const cv::Mat &pred_data, bool return_word_box) const { std::vector shape_squeeze = {}; for (int i = 1; i < pred_data.dims; i++) { shape_squeeze.push_back(pred_data.size[i]); @@ -190,16 +202,16 @@ CTCLabelDecode::Process(const cv::Mat &pred_data) const { text_index.push_back(max_idx); text_prob.push_back(max_val); } - auto decode_result = Decode(text_index, text_prob, true); + auto decode_result = Decode(text_index, text_prob, true, return_word_box); if (!decode_result.ok()) { return decode_result.status(); } return decode_result.value(); } -absl::StatusOr> +absl::StatusOr CTCLabelDecode::Decode(std::list &text_index, std::list &text_prob, - bool is_remove_duplicate) const { + bool is_remove_duplicate, bool return_word_box) const { std::vector selection(text_index.size(), true); if (is_remove_duplicate && text_index.size() > 1) { auto prev = text_index.begin(); @@ -258,10 +270,97 @@ CTCLabelDecode::Decode(std::list &text_index, std::list &text_prob, } float sum = std::accumulate(conf_list.begin(), conf_list.end(), 0.0f); float mean = sum / conf_list.size(); + CTCLabelDecodeResult result; + if (return_word_box) { + auto word_info_tuple = GetWordInfo(text, selection); + result.sentence_len = selection.size(); + result.word_list = std::get<0>(word_info_tuple); + result.word_col_list = std::get<1>(word_info_tuple); + result.state_list = std::get<2>(word_info_tuple); + } + result.sentence = std::pair(text, mean); - return std::pair(text, mean); + return result; } void CTCLabelDecode::AddSpecialChar() { character_list_.insert(character_list_.begin(), "blank"); } + +std::tuple, std::vector>, + std::vector> +CTCLabelDecode::GetWordInfo(const std::string &text_origin, + const std::vector &selection) const { + + std::wstring_convert> converter; + std::wstring text = converter.from_bytes(text_origin); + + std::string state = ""; + std::wstring word_content = L""; + std::vector word_col_content; + + std::vector word_list = {}; + std::vector> word_col_list = {}; + std::vector state_list = {}; + + std::vector valid_col; + for (int i = 0; i < selection.size(); ++i) { + if (selection[i]) { + valid_col.push_back(i); + } + } + + std::wregex en_num_pattern(L"[a-zA-Z0-9]"); + std::wregex num_pattern(L"[0-9]"); + + for (int c_i = 0; c_i < text.length(); ++c_i) { + wchar_t ch = text[c_i]; + std::string c_state; + + if (ch >= L'\u4e00' && ch <= L'\u9fff') { + c_state = "cn"; + } else if (std::regex_search(std::wstring(1, ch), en_num_pattern)) { // [5] + c_state = "en&num"; + } else { + c_state = "symbol"; + } + + if (ch == L'.' && state == "en&num" && c_i + 1 < text.length()) { + if (std::regex_search(std::wstring(1, text[c_i + 1]), num_pattern)) { + c_state = "en&num"; + } + } + + if (ch == L'-' && state == "en&num") { + c_state = "en&num"; + } + + if (state.empty()) { + state = c_state; + } + + if (state != c_state && !word_content.empty()) { + if (!word_content.empty()) { + word_list.push_back(word_content); + word_col_list.push_back(word_col_content); + state_list.push_back(state); + word_content.clear(); + word_col_content.clear(); + } + state = c_state; + } + + word_content += ch; + if (c_i < valid_col.size()) { + word_col_content.push_back(valid_col[c_i]); + } + } + + if (!word_content.empty()) { + word_list.push_back(word_content); + word_col_list.push_back(word_col_content); + state_list.push_back(state); + } + auto result = std::make_tuple(word_list, word_col_list, state_list); + return result; +} diff --git a/deploy/cpp_infer/src/modules/text_recognition/processors.h b/deploy/cpp_infer/src/modules/text_recognition/processors.h index 76aef3af99..e5d2d53028 100644 --- a/deploy/cpp_infer/src/modules/text_recognition/processors.h +++ b/deploy/cpp_infer/src/modules/text_recognition/processors.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "absl/status/status.h" @@ -24,11 +25,10 @@ #include "src/common/processors.h" #include "src/utils/func_register.h" -class OCRReisizeNormImg : public BaseProcessor { +class OCRResizeNormImg : public BaseProcessor { public: - OCRReisizeNormImg( - absl::optional> input_shape = absl::nullopt, - std::vector rec_image_shape = {3, 48, 320}) + OCRResizeNormImg(absl::optional> input_shape = absl::nullopt, + std::vector rec_image_shape = {3, 48, 320}) : rec_image_shape_(rec_image_shape), input_shape_(input_shape.value_or(std::vector())){}; absl::StatusOr> @@ -45,17 +45,34 @@ class OCRReisizeNormImg : public BaseProcessor { std::vector input_shape_; }; +struct CTCLabelDecodeResult { + std::pair sentence; + float sentence_len = -1; + std::vector word_list = {}; + std::vector> word_col_list = {}; + std::vector state_list = {}; +}; + class CTCLabelDecode { public: CTCLabelDecode(const std::vector &character_list = {}, bool use_space_char = true); - absl::StatusOr>> - Apply(const cv::Mat &preds) const; - absl::StatusOr> - Process(const cv::Mat &pred_data) const; - absl::StatusOr> + + absl::StatusOr> + Apply(const cv::Mat &preds, const bool return_word_box = false, + std::vector wh_ratio_list = {}, float max_wh_ratio = 0.0) const; + + absl::StatusOr + Process(const cv::Mat &pred_data, bool return_word_box = false) const; + + absl::StatusOr Decode(std::list &text_index, std::list &text_prob, - bool is_remove_duplicate = false) const; + bool is_remove_duplicate = false, bool return_word_box = false) const; + + std::tuple, std::vector>, + std::vector> + GetWordInfo(const std::string &text_origin, + const std::vector &selection) const; void AddSpecialChar(); private: diff --git a/deploy/cpp_infer/src/pipelines/ocr/pipeline.cc b/deploy/cpp_infer/src/pipelines/ocr/pipeline.cc index 0678db7980..1a8933406e 100644 --- a/deploy/cpp_infer/src/pipelines/ocr/pipeline.cc +++ b/deploy/cpp_infer/src/pipelines/ocr/pipeline.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "pipeline.h" +#include #include "result.h" #include "src/utils/args.h" @@ -214,6 +215,9 @@ _OCRPipeline::_OCRPipeline(const OCRPipelineParams ¶ms) params_rec.input_shape = config_.SmartParseVector(result_rec_input_shape.value()).vec_int; } + params_rec.return_word_box = + config_.GetBool("TextRecognition.return_word_box", false).value(); + return_word_box_ = params_rec.return_word_box.value(); params_rec.model_dir = result_text_rec_model_dir.value(); params_rec.lang = params_.lang; params_rec.ocr_version = params_.ocr_version; @@ -349,6 +353,7 @@ _OCRPipeline::Predict(const std::vector &input) { results[k].text_det_params = text_det_params_; results[k].text_type = text_type_; results[k].text_rec_score_thresh = text_rec_score_thresh_; + results[k].return_word_box = return_word_box_; } if (!indices.empty()) { std::vector all_subs_of_imgs = {}; @@ -435,6 +440,32 @@ _OCRPipeline::Predict(const std::vector &input) { for (int sno = 0; sno < sub_img_info_list.size(); sno++) { auto rec_res = sub_img_info_list[sno].second; if (rec_res.rec_score >= text_rec_score_thresh_) { + if (return_word_box_) { + auto word_result = ComponentsProcessor::CalOCRWordBox( + rec_res.rec_text, dt_polys_list[l][sno], + rec_res.ctc_result.sentence_len, rec_res.ctc_result.word_list, + rec_res.ctc_result.word_col_list, + rec_res.ctc_result.state_list); + auto word_box_content_list = word_result.first; + auto word_box_list = word_result.second; + std::vector word_content = {}; + for (int word_index = 0; + word_index < word_box_content_list.size(); word_index++) { + if (rec_res.ctc_result.state_list[word_index] == "cn") { + for (auto &word : word_box_content_list[word_index]) { + std::wstring ws(1, word); + std::wstring_convert> conv; + word_content.push_back(conv.to_bytes(ws)); + } + } else { + std::wstring_convert> conv; + word_content.push_back( + conv.to_bytes(word_box_content_list[word_index])); + } + } + results[l].text_word.push_back(word_content); + results[l].text_word_region.push_back(word_box_list); + } results[l].rec_texts.push_back(rec_res.rec_text); results[l].rec_scores.push_back(rec_res.rec_score); results[l].rec_polys.push_back(dt_polys_list[l][sno]); @@ -447,6 +478,12 @@ _OCRPipeline::Predict(const std::vector &input) { if (text_type_ == "general") { res.rec_boxes = ComponentsProcessor::ConvertPointsToBoxes(res.rec_polys); + if (return_word_box_) { + for (auto &line : res.text_word_region) { + res.text_word_boxes.push_back( + ComponentsProcessor::ConvertPointsToBoxes(line)); + } + } } pipeline_result_vec_.push_back(res); base_results.push_back(std::unique_ptr(new OCRResult(res))); @@ -764,4 +801,16 @@ void _OCRPipeline::OverrideConfig() { data[key] = Utility::VecToString(params_.text_rec_input_shape.value()); } } + + if (params_.return_word_box.has_value()) { + auto it = config_.FindKey("TextRecognition.return_word_box"); + if (!it.ok()) { + data["SubModules.TextRecognition.return_word_box"] = + params_.return_word_box ? "true" : "false"; + } else { + auto key = it.value().first; + data.erase(data.find(key)); + data[key] = params_.return_word_box ? "true" : "false"; + } + } } diff --git a/deploy/cpp_infer/src/pipelines/ocr/pipeline.h b/deploy/cpp_infer/src/pipelines/ocr/pipeline.h index 875c8949a9..3b7817bd17 100644 --- a/deploy/cpp_infer/src/pipelines/ocr/pipeline.h +++ b/deploy/cpp_infer/src/pipelines/ocr/pipeline.h @@ -50,10 +50,14 @@ struct OCRPipelineResult { std::string text_type = ""; float text_rec_score_thresh = 0.0; std::vector rec_texts = {}; + std::vector> text_word = {}; + std::vector>> text_word_region = {}; std::vector rec_scores = {}; + bool return_word_box = false; std::vector textline_orientation_angles = {}; std::vector> rec_polys = {}; - std::vector> rec_boxes = {}; + std::vector> rec_boxes = {}; + std::vector>> text_word_boxes = {}; std::string vis_fonts = ""; }; @@ -83,6 +87,7 @@ struct OCRPipelineParams { absl::optional> text_det_input_shape = absl::nullopt; absl::optional text_rec_score_thresh = absl::nullopt; absl::optional> text_rec_input_shape = absl::nullopt; + absl::optional return_word_box = absl::nullopt; absl::optional lang = absl::nullopt; absl::optional ocr_version = absl::nullopt; absl::optional vis_font_dir = absl::nullopt; @@ -137,6 +142,7 @@ class _OCRPipeline : public BasePipeline { float text_rec_score_thresh_ = 0.0; std::string text_type_; TextDetParams text_det_params_; + bool return_word_box_; }; class OCRPipeline diff --git a/deploy/cpp_infer/src/pipelines/ocr/result.cc b/deploy/cpp_infer/src/pipelines/ocr/result.cc index f1fb4be640..ae3c8ac6ff 100644 --- a/deploy/cpp_infer/src/pipelines/ocr/result.cc +++ b/deploy/cpp_infer/src/pipelines/ocr/result.cc @@ -28,11 +28,42 @@ using json = nlohmann::json; void OCRResult::SaveToImg(const std::string &save_path) { cv::Mat image = pipeline_result_.doc_preprocessor_res.output_image; - auto texts = pipeline_result_.rec_texts; - std::vector> boxes; - std::vector> boxes_float = - pipeline_result_.rec_polys; - for (const auto &floatPolygon : pipeline_result_.rec_polys) { + std::vector texts = {}; + std::vector> boxes = {}; + std::vector> boxes_float = {}; + if (pipeline_result_.return_word_box) { + std::vector> flat_word_region; + for (const auto &sublist : pipeline_result_.text_word_region) { + for (const auto &item : sublist) { + flat_word_region.push_back(item); + } + } + std::vector text_word = {}; + for (const auto &word : pipeline_result_.text_word) { + for (const auto &item : word) { + text_word.push_back(item); + } + } + for (size_t idx = 0; idx < flat_word_region.size(); ++idx) { + const std::vector &word_region = flat_word_region[idx]; + if (word_region.size() < 4) + continue; + int box_height = static_cast( + std::sqrt(std::pow(word_region[0].x - word_region[3].x, 2) + + std::pow(word_region[0].y - word_region[3].y, 2))); + int box_width = static_cast( + std::sqrt(std::pow(word_region[0].x - word_region[1].x, 2) + + std::pow(word_region[0].y - word_region[1].y, 2))); + if (box_height == 0 || box_width == 0) + continue; + boxes_float.push_back(word_region); + texts.push_back(text_word[idx]); + } + } else { + texts = pipeline_result_.rec_texts; + boxes_float = pipeline_result_.rec_polys; + } + for (const auto &floatPolygon : boxes_float) { std::vector intPolygon; for (const auto &point : floatPolygon) { intPolygon.push_back(cv::Point(cvRound(point.x), cvRound(point.y))); @@ -356,6 +387,7 @@ void OCRResult::SaveToJson(const std::string &save_path) const { pipeline_result_.textline_orientation_angles; } j["text_rec_score_thresh"] = pipeline_result_.text_rec_score_thresh; + j["return_word_box"] = pipeline_result_.return_word_box; j["rec_texts"] = pipeline_result_.rec_texts; j["rec_scores"] = pipeline_result_.rec_scores; json rec_polys_json = json::array(); @@ -369,19 +401,11 @@ void OCRResult::SaveToJson(const std::string &save_path) const { } j["rec_polys"] = rec_polys_json; - std::vector> int_vec; - int_vec.reserve(pipeline_result_.rec_boxes.size()); - - std::transform(pipeline_result_.rec_boxes.begin(), - pipeline_result_.rec_boxes.end(), std::back_inserter(int_vec), - [](const std::array &arr) { - std::array res; - for (size_t i = 0; i < 4; ++i) { - res[i] = static_cast(arr[i]); - } - return res; - }); - j["rec_boxes"] = int_vec; + j["rec_boxes"] = pipeline_result_.rec_boxes; + if (pipeline_result_.return_word_box) { + j["text_word_boxes"] = pipeline_result_.text_word_boxes; + j["text_word"] = pipeline_result_.text_word; + } absl::StatusOr full_path; if (pipeline_result_.input_path.empty()) { @@ -479,7 +503,7 @@ void PrintIntArray(const std::vector &arr) { std::cout << "]"; } -void PrintRecBoxes(const std::vector> &arr) { +void PrintRecBoxes(const std::vector> &arr) { std::cout << "["; for (size_t i = 0; i < arr.size(); ++i) { if (i != 0) @@ -509,18 +533,33 @@ void OCRResult::Print() const { PrintDocPreprocessorPipelineResult(pipeline_result_.doc_preprocessor_res); std::cout << ",\n"; } - std::cout << " \"dt_polys\": "; - PrintPolys(pipeline_result_.dt_polys); - std::cout << ",\n"; std::cout << " \"model_settings\": "; PrintModelSettings(pipeline_result_.model_settings); std::cout << ",\n"; + std::cout << " \"dt_polys\": "; + PrintPolys(pipeline_result_.dt_polys); + std::cout << ",\n"; std::cout << " \"text_det_params\": "; PrintTextDetParams(pipeline_result_.text_det_params); std::cout << ",\n"; std::cout << " \"text_type\": \"" << pipeline_result_.text_type << "\",\n"; std::cout << " \"text_rec_score_thresh\": " << pipeline_result_.text_rec_score_thresh << ",\n"; + std::cout << " \"return_word_box\": " + << (pipeline_result_.return_word_box ? "true" : "false") << ",\n"; + + std::cout << " \"text_word_boxes\": "; + if (pipeline_result_.return_word_box) { + for (const auto &item : pipeline_result_.text_word_boxes) { + PrintRecBoxes(item); + } + } + std::cout << " \"text_word\": "; + if (pipeline_result_.return_word_box) { + for (const auto &item : pipeline_result_.text_word) { + PrintStringArray(item); + } + } std::cout << " \"rec_texts\": "; PrintStringArray(pipeline_result_.rec_texts); std::cout << ",\n"; diff --git a/deploy/cpp_infer/src/utils/args.cc b/deploy/cpp_infer/src/utils/args.cc index 444cf9cb62..5fe0e746f5 100644 --- a/deploy/cpp_infer/src/utils/args.cc +++ b/deploy/cpp_infer/src/utils/args.cc @@ -75,6 +75,7 @@ DEFINE_string(text_rec_score_thresh, "0", "than this threshold are retained."); DEFINE_string(text_rec_input_shape, "", "Input shape of the text recognition model.eg C,H,W"); +DEFINE_string(return_word_box, "", "Determines whether to return word box"); DEFINE_string(lang, "", "Language in the input image for OCR processing."); DEFINE_string(ocr_version, "", "PP-OCR version to use."); #ifdef WITH_GPU diff --git a/deploy/cpp_infer/src/utils/args.h b/deploy/cpp_infer/src/utils/args.h index af8091d4a9..80579dcb29 100644 --- a/deploy/cpp_infer/src/utils/args.h +++ b/deploy/cpp_infer/src/utils/args.h @@ -41,6 +41,7 @@ DECLARE_string(text_det_unclip_ratio); DECLARE_string(text_det_input_shape); DECLARE_string(text_rec_score_thresh); DECLARE_string(text_rec_input_shape); +DECLARE_string(return_word_box); DECLARE_string(lang); DECLARE_string(ocr_version); DECLARE_string(device); diff --git a/deploy/cpp_infer/src/utils/yaml_config.cc b/deploy/cpp_infer/src/utils/yaml_config.cc index e1132c309c..effdbbfd43 100644 --- a/deploy/cpp_infer/src/utils/yaml_config.cc +++ b/deploy/cpp_infer/src/utils/yaml_config.cc @@ -114,6 +114,13 @@ void YamlConfig::Init() { pre_process_op_info_["CropImage.size"] = info.second; } else if (info.first.find("ToCHWImage") != std::string::npos) { pre_process_op_info_["ToCHWImage"] = info.second; + } else if (info.first.find("RecResizeImg.image_shape") != + std::string::npos) { + size_t pos = info.first.find("RecResizeImg.image_shape"); + size_t after = pos + std::string("RecResizeImg.image_shape").size(); + if (info.first[after] != '[') { + pre_process_op_info_["RecResizeImg.image_shape"] = info.second; + } } else if (info.first.find("KeepKeys.keep_keys") != std::string::npos) { pre_process_op_info_["KeepKeys.keep_keys"] = info.second; } else if (info.first.find("PostProcess.name") != std::string::npos) { From 0a26a8dfa150f1e48fc2821c6b6641f7ec5d0a6b Mon Sep 17 00:00:00 2001 From: guoshengjian Date: Fri, 29 Aug 2025 13:23:10 +0000 Subject: [PATCH 2/2] modify ocr.md --- docs/version3.x/deployment/cpp/OCR.en.md | 6 ++++++ docs/version3.x/deployment/cpp/OCR.md | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/docs/version3.x/deployment/cpp/OCR.en.md b/docs/version3.x/deployment/cpp/OCR.en.md index 71d5c53ebe..76776bfd89 100644 --- a/docs/version3.x/deployment/cpp/OCR.en.md +++ b/docs/version3.x/deployment/cpp/OCR.en.md @@ -604,6 +604,12 @@ Any floating-point number greater than 0. If not set, it will use t str "" + +return_word_box +Whether to return the bounding box of a single character. If not set, it will use the default value initialized by the pipeline, which is initialized to false by default. +bool +false + diff --git a/docs/version3.x/deployment/cpp/OCR.md b/docs/version3.x/deployment/cpp/OCR.md index 9e83de7b0a..c3f470c35b 100644 --- a/docs/version3.x/deployment/cpp/OCR.md +++ b/docs/version3.x/deployment/cpp/OCR.md @@ -603,6 +603,12 @@ MKL-DNN 缓存容量。 str "" + +return_word_box +是否返回单字符坐标框。如果不设置,将使用产线初始化的该参数值,默认初始化为 false。 +bool +false +