diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..86f68ea8 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,13 @@ +**Issues:** + +- https://github.com/openmpf/openmpf/issues/1 < Example +- https://github.com/openmpf/openmpf/issues/2 < Example + +If a related issue doesn't exist, then create one first and assign it to yourself. + +**Related PRs:** + +- https://github.com/openmpf/openmpf/pull/22 < Example +- https://github.com/openmpf/openmpf/pull/53 < Example + +Please review our [Contributor Guide](https://openmpf.github.io/docs/site/Contributor-Guide/index.html). diff --git a/cpp/TesseractOCRTextDetection/README.md b/cpp/TesseractOCRTextDetection/README.md index e27637d7..77f4409e 100755 --- a/cpp/TesseractOCRTextDetection/README.md +++ b/cpp/TesseractOCRTextDetection/README.md @@ -3,8 +3,11 @@ This repository contains source code and model data for the OpenMPF Tesseract OCR text detection component. -The component extracts text found in an image, reported as a single track -detection. PDF documents can also be processed with one track detection per +The component extracts text found in an image, video, or generic document. +Image results are reported as a single track +detection [per specified language setting](#detecting-multiple-languages). +Video results are reported as single track detections per frame and language setting. +PDF documents can also be processed with one track detection per page. The first page corresponds to the detection property `PAGE_NUM=1`. For debugging purposes, images converted from documents are stored in a temporary job directory under `plugin/TesseractOCR/tmp-[job-id]-[random tag]`. This diff --git a/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.cpp b/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.cpp index e4ffe231..45dc838b 100755 --- a/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.cpp +++ b/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.cpp @@ -28,7 +28,6 @@ #include #include -#include #include #include @@ -50,13 +49,16 @@ #include #include +#include #include using namespace MPF; using namespace COMPONENT; using namespace std; + using log4cxx::Logger; + typedef boost::multi_index::multi_index_container, @@ -66,9 +68,9 @@ typedef boost::multi_index::multi_index_container TesseractOCRTextDetection::GetDetections(const MPFAudioJob &job) { + throw MPFDetectionException(MPF_UNSUPPORTED_DATA_TYPE); } /* @@ -409,7 +416,7 @@ string random_string(size_t length) { /* * Preprocess image before running OSD and OCR. */ -void TesseractOCRTextDetection::preprocess_image(const MPFImageJob &job, cv::Mat &image_data, +void TesseractOCRTextDetection::preprocess_image(const MPFJob &job, cv::Mat &image_data, const OCR_filter_settings &ocr_fset) { if (image_data.empty()) { throw MPFDetectionException(MPF_IMAGE_READ_ERROR, @@ -453,17 +460,17 @@ void TesseractOCRTextDetection::preprocess_image(const MPFImageJob &job, cv::Mat cv::subtract(tmp_imb, image_data, image_data); } } - catch (const std::exception& ex) { + catch (const exception& ex) { throw MPFDetectionException( MPF_OTHER_DETECTION_ERROR_TYPE, - std::string("Error during image preprocessing: ") + ex.what()); + string("Error during image preprocessing: ") + ex.what()); } } /* * Rescales image before running OSD and OCR. 
*/ -void TesseractOCRTextDetection::rescale_image(const MPFImageJob &job, cv::Mat &image_data, +void TesseractOCRTextDetection::rescale_image(const MPFJob &job, cv::Mat &image_data, const OCR_filter_settings &ocr_fset) { int im_width = image_data.size().width; int im_height = image_data.size().height; @@ -481,12 +488,12 @@ void TesseractOCRTextDetection::rescale_image(const MPFImageJob &job, cv::Mat &i if (im_height < ocr_fset.invalid_min_image_size) { throw MPFDetectionException(MPF_BAD_FRAME_SIZE, - "Invalid image height, image too short: " + std::to_string(im_height)); + "Invalid image height, image too short: " + to_string(im_height)); } if (im_width < ocr_fset.invalid_min_image_size) { throw MPFDetectionException(MPF_BAD_FRAME_SIZE, - "Invalid image width, image too narrow: " + std::to_string(im_width)); + "Invalid image width, image too narrow: " + to_string(im_width)); } int min_dim = im_width; @@ -530,7 +537,7 @@ void TesseractOCRTextDetection::rescale_image(const MPFImageJob &job, cv::Mat &i // If rescale still exceeds pixel limits, decrease further. if (ocr_fset.max_pixels > 0 && (default_rescale * im_width) * (im_height * default_rescale) > ocr_fset.max_pixels) { - default_rescale = std::sqrt((double)ocr_fset.max_pixels / (double)(im_height * im_width)); + default_rescale = sqrt((double)ocr_fset.max_pixels / (double)(im_height * im_width)); } if (min_dim * default_rescale < ocr_fset.invalid_min_image_size) { @@ -577,9 +584,9 @@ void TesseractOCRTextDetection::rescale_image(const MPFImageJob &job, cv::Mat &i if (ocr_fset.sharpen > 0) { sharpen(image_data, ocr_fset.sharpen); } - } catch (const std::exception& ex) { + } catch (const exception& ex) { throw MPFDetectionException(MPF_OTHER_DETECTION_ERROR_TYPE, - std::string("Error during image rescaling: ") + ex.what()); + string("Error during image rescaling: ") + ex.what()); } } @@ -594,7 +601,7 @@ void TesseractOCRTextDetection::process_tesseract_lang_model(OCR_job_inputs &inp bool tess_api_in_map = inputs.tess_api_map->find(tess_api_key) != inputs.tess_api_map->end(); string tessdata_dir; - std::unique_ptr tess_api_for_parallel; + unique_ptr tess_api_for_parallel; set languages_found, missing_languages; // If running OSD scripts, set tessdata_dir to the location found during OSD processing. // Otherwise, for each individual language setting, locate the appropriate tessdata directory. @@ -612,7 +619,6 @@ void TesseractOCRTextDetection::process_tesseract_lang_model(OCR_job_inputs &inp // Fail if user specified language is missing. // This error should not occur as both OSD and user specified languages have been already checked. 
if (!missing_languages.empty()) { - results.job_status = MPF_COULD_NOT_OPEN_DATAFILE; throw MPFDetectionException(MPF_COULD_NOT_OPEN_DATAFILE, "Tesseract language models not found."); } @@ -624,9 +630,9 @@ void TesseractOCRTextDetection::process_tesseract_lang_model(OCR_job_inputs &inp } else if (!tess_api_in_map) { inputs.tess_api_map->emplace( - std::piecewise_construct, - std::forward_as_tuple(tess_api_key), - std::forward_as_tuple(tessdata_dir, results.lang, (tesseract::OcrEngineMode) inputs.ocr_fset->oem)); + piecewise_construct, + forward_as_tuple(tess_api_key), + forward_as_tuple(tessdata_dir, results.lang, (tesseract::OcrEngineMode) inputs.ocr_fset->oem)); } } else if (!missing_languages.empty()) { LOG4CXX_WARN(inputs.hw_logger_, "[" + *inputs.job_name + "] Tesseract language models no longer found in " + @@ -640,14 +646,14 @@ void TesseractOCRTextDetection::process_tesseract_lang_model(OCR_job_inputs &inp tess_api.SetPageSegMode((tesseract::PageSegMode) inputs.ocr_fset->psm); tess_api.SetImage(*inputs.imi); - std::string text = tess_api.GetUTF8Text(); + string text = tess_api.GetUTF8Text(); results.confidence = tess_api.MeanTextConf(); // Free up recognition results and any stored image data. tess_api.Clear(); if (!text.empty()) { - results.text_result = std::move(text); + results.text_result = move(text); } else { LOG4CXX_WARN(inputs.hw_logger_, "[" + *inputs.job_name + "] OCR processing unsuccessful, no outputs."); } @@ -657,26 +663,25 @@ void TesseractOCRTextDetection::process_parallel_image_runs(OCR_job_inputs &inpu Image_results &results) { OCR_results ocr_results[inputs.ocr_lang_inputs.size()]; - std::future ocr_threads[inputs.ocr_lang_inputs.size()]; + future ocr_threads[inputs.ocr_lang_inputs.size()]; int index = 0; - std::set active_threads; + set active_threads; // Initialize a new track for each language specified. for (const string &lang: inputs.ocr_lang_inputs) { - ocr_results[index].job_status = results.job_status; ocr_results[index].lang = lang; - ocr_threads[index] = std::async(launch::async, + ocr_threads[index] = async(launch::async, &process_tesseract_lang_model, - std::ref(inputs), - std::ref(ocr_results[index])); + ref(inputs), + ref(ocr_results[index])); active_threads.insert(index); index++; while (active_threads.size() >= inputs.ocr_fset->max_parallel_ocr_threads) { - std::this_thread::sleep_for(std::chrono::seconds(1)); + this_thread::sleep_for(chrono::seconds(1)); for (auto active_thread_iter = active_threads.begin(); active_thread_iter != active_threads.end(); /* Intentionally no ++ to support erasing. */) { int i = *active_thread_iter; - if (ocr_threads[i].wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { + if (ocr_threads[i].wait_for(chrono::milliseconds(0)) == future_status::ready) { // Will re-throw exception from thread. 
ocr_threads[i].get(); active_thread_iter = active_threads.erase(active_thread_iter); @@ -708,8 +713,6 @@ void TesseractOCRTextDetection::process_serial_image_runs(OCR_job_inputs &inputs Image_results &results) { for (const string &lang: inputs.ocr_lang_inputs) { OCR_results ocr_results; - - ocr_results.job_status = results.job_status; ocr_results.lang = lang; process_tesseract_lang_model(inputs, ocr_results); @@ -721,7 +724,7 @@ void TesseractOCRTextDetection::process_serial_image_runs(OCR_job_inputs &inputs } } -bool TesseractOCRTextDetection::process_ocr_text(Properties &detection_properties, const MPFImageJob &job, +bool TesseractOCRTextDetection::process_ocr_text(Properties &detection_properties, const MPFJob &job, const TesseractOCRTextDetection::OCR_output &ocr_out, const TesseractOCRTextDetection::OCR_filter_settings &ocr_fset, int page_num) { @@ -885,7 +888,7 @@ string TesseractOCRTextDetection::process_osd_lang(const string &script_type, co // Both the scaled imi image and unscaled imi_original image are also needed since vertical rotations may result in // changes to the original image rescaling. The final rescaled image will be stored in imi_scaled. bool TesseractOCRTextDetection::get_OSD_rotation(OSResults *results, cv::Mat &imi_scaled, cv::Mat &imi_original, - int &rotation, const MPFImageJob &job, OCR_filter_settings &ocr_fset) { + int &rotation, const MPFJob &job, OCR_filter_settings &ocr_fset) { switch (results->best_result.orientation_id) { case 0: @@ -932,7 +935,7 @@ bool TesseractOCRTextDetection::get_OSD_rotation(OSResults *results, cv::Mat &im return true; } -void TesseractOCRTextDetection::get_OSD(OSBestResult &best_result, cv::Mat &imi, const MPFImageJob &job, +void TesseractOCRTextDetection::get_OSD(OSBestResult &best_result, cv::Mat &imi, const MPFJob &job, OCR_filter_settings &ocr_fset, Properties &detection_properties, string &tessdata_script_dir, set &missing_languages) { @@ -971,9 +974,9 @@ void TesseractOCRTextDetection::get_OSD(OSBestResult &best_result, cv::Mat &imi, } tess_api_map.emplace( - std::piecewise_construct, - std::forward_as_tuple(tess_api_key), - std::forward_as_tuple(tessdata_dir, "osd", (tesseract::OcrEngineMode) oem)); + piecewise_construct, + forward_as_tuple(tess_api_key), + forward_as_tuple(tessdata_dir, "osd", (tesseract::OcrEngineMode) oem)); LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] OSD model ready."); } @@ -1227,13 +1230,11 @@ void TesseractOCRTextDetection::get_OSD(OSBestResult &best_result, cv::Mat &imi, } - void -TesseractOCRTextDetection::load_settings(const MPFJob &job, OCR_filter_settings &ocr_fset, - const Text_type &text_type) { - // Load in settings specified from job_properties and default configuration. 
- - // Image preprocessing +TesseractOCRTextDetection::load_image_preprocessing_settings(const MPFJob &job, + OCR_filter_settings &ocr_fset, + const Text_type &text_type) { + // Image preprocessing bool default_processing_wild = DetectionComponentUtils::GetProperty(job.job_properties, "UNSTRUCTURED_TEXT_ENABLE_PREPROCESSING", default_ocr_fset.processing_wild_text); if ((text_type == Unstructured) || (text_type == Unknown && default_processing_wild)) { @@ -1251,7 +1252,12 @@ TesseractOCRTextDetection::load_settings(const MPFJob &job, OCR_filter_settings ocr_fset.enable_otsu_thrs = DetectionComponentUtils::GetProperty(job.job_properties,"STRUCTURED_TEXT_ENABLE_OTSU_THRS", default_ocr_fset.enable_otsu_thrs); ocr_fset.sharpen = DetectionComponentUtils::GetProperty(job.job_properties,"STRUCTURED_TEXT_SHARPEN", default_ocr_fset.sharpen); } +} + +void +TesseractOCRTextDetection::load_settings(const MPFJob &job, OCR_filter_settings &ocr_fset) { + // Load in settings specified from job_properties and default configuration ocr_fset.invert = DetectionComponentUtils::GetProperty(job.job_properties,"INVERT", default_ocr_fset.invert); ocr_fset.min_height = DetectionComponentUtils::GetProperty(job.job_properties, "MIN_HEIGHT", default_ocr_fset.min_height); ocr_fset.invalid_min_image_size = DetectionComponentUtils::GetProperty(job.job_properties, "INVALID_MIN_IMAGE_SIZE", default_ocr_fset.invalid_min_image_size); @@ -1261,7 +1267,7 @@ TesseractOCRTextDetection::load_settings(const MPFJob &job, OCR_filter_settings ocr_fset.adaptive_thrs_pixel = DetectionComponentUtils::GetProperty(job.job_properties,"ADAPTIVE_THRS_BLOCKSIZE", default_ocr_fset.adaptive_thrs_pixel); // OCR and OSD Engine Settings. - ocr_fset.tesseract_lang = DetectionComponentUtils::GetProperty(job.job_properties,"TESSERACT_LANGUAGE", default_ocr_fset.tesseract_lang); + ocr_fset.tesseract_lang = DetectionComponentUtils::GetProperty(job.job_properties,"TESSERACT_LANGUAGE", default_ocr_fset.tesseract_lang); ocr_fset.psm = DetectionComponentUtils::GetProperty(job.job_properties,"TESSERACT_PSM", default_ocr_fset.psm); ocr_fset.oem = DetectionComponentUtils::GetProperty(job.job_properties,"TESSERACT_OEM", default_ocr_fset.oem); @@ -1282,8 +1288,8 @@ TesseractOCRTextDetection::load_settings(const MPFJob &job, OCR_filter_settings ocr_fset.max_parallel_pdf_threads = DetectionComponentUtils::GetProperty(job.job_properties, "MAX_PARALLEL_PAGE_THREADS", default_ocr_fset.max_parallel_pdf_threads); // Tessdata setup - ocr_fset.model_dir = DetectionComponentUtils::GetProperty(job.job_properties, "MODELS_DIR_PATH", default_ocr_fset.model_dir); - ocr_fset.tessdata_models_subdir = DetectionComponentUtils::GetProperty(job.job_properties, "TESSDATA_MODELS_SUBDIRECTORY", default_ocr_fset.tessdata_models_subdir); + ocr_fset.model_dir = DetectionComponentUtils::GetProperty(job.job_properties, "MODELS_DIR_PATH", default_ocr_fset.model_dir); + ocr_fset.tessdata_models_subdir = DetectionComponentUtils::GetProperty(job.job_properties, "TESSDATA_MODELS_SUBDIRECTORY", default_ocr_fset.tessdata_models_subdir); if (ocr_fset.model_dir != "") { ocr_fset.model_dir = ocr_fset.model_dir + "/" + ocr_fset.tessdata_models_subdir; } else { @@ -1342,9 +1348,7 @@ inline void set_coordinates(int &xLeftUpper, int &yLeftUpper, int &width, int &h */ void TesseractOCRTextDetection::check_default_languages(const OCR_filter_settings &ocr_fset, const string &job_name, - const string &run_dir, - MPFDetectionError &job_status) { - + const string &run_dir) { set languages_found, 
missing_languages; set lang_list = generate_lang_set(ocr_fset.tesseract_lang); bool tess_api_in_map = true; @@ -1372,10 +1376,74 @@ void TesseractOCRTextDetection::check_default_languages(const OCR_filter_setting } } +vector TesseractOCRTextDetection::GetDetections(const MPFVideoJob &job) { + try { + LOG4CXX_INFO(hw_logger_, "[" + job.job_name + "] Starting job."); + + MPFVideoCapture cap(job); + + OCR_filter_settings ocr_fset; + + string run_dir = GetRunDirectory(); + cv::Mat frame; + vector tracks; + + int frame_index = 0; + + load_settings(job, ocr_fset); + check_default_languages(ocr_fset, job.job_name, run_dir); + + // Depending on the frame, the default language will be replaced by script detection results. + string default_lang = ocr_fset.tesseract_lang; + + while (cap.Read(frame)) { + Text_type text_type = Unknown; + if (job.has_feed_forward_track && + job.feed_forward_track.frame_locations.count(frame_index) && + job.feed_forward_track.frame_locations.at(frame_index).detection_properties.count("TEXT_TYPE")) { + string job_text_type = job.feed_forward_track.frame_locations.at(frame_index).detection_properties.at( + "TEXT_TYPE"); + if (job_text_type == "UNSTRUCTURED") { + text_type = Unstructured; + } else if (job_text_type == "STRUCTURED") { + text_type = Structured; + } + + LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] Identified text type: \"" + + job_text_type + "\"."); + } + + load_image_preprocessing_settings(job, ocr_fset, text_type); + vector locations = process_image_job(job, ocr_fset, frame, run_dir); + for (auto location: locations) { + MPFVideoTrack video_track(frame_index, frame_index); + video_track.confidence = location.confidence; + video_track.frame_locations[frame_index] = location; + tracks.push_back(video_track); + } + + ocr_fset.tesseract_lang = default_lang; + frame_index++; + } + + for (MPFVideoTrack &track : tracks) { + cap.ReverseTransform(track); + } + + LOG4CXX_INFO(hw_logger_, + "[" + job.job_name + "] Processing complete. Found " + to_string(tracks.size()) + " tracks."); + + return tracks; + } + catch (...) 
{ + Utils::LogAndReThrowException(job, hw_logger_); + } +} + vector TesseractOCRTextDetection::GetDetections(const MPFImageJob &job) { - LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] Processing \"" + job.data_uri + "\"."); try{ + LOG4CXX_INFO(hw_logger_, "[" + job.job_name + "] Starting job."); OCR_filter_settings ocr_fset; Text_type text_type = Unknown; @@ -1389,213 +1457,216 @@ vector TesseractOCRTextDetection::GetDetections(const MPFImage LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] Identified text type: \"" + job.feed_forward_location.detection_properties.at("TEXT_TYPE") + "\"."); } - load_settings(job, ocr_fset, text_type); - - MPFDetectionError job_status = MPF_DETECTION_SUCCESS; - map>> json_kvs_regex; + load_settings(job, ocr_fset); + load_image_preprocessing_settings(job, ocr_fset, text_type); - LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] About to run tesseract"); - vector ocr_outputs; MPFImageReader image_reader(job); cv::Mat image_data = image_reader.GetImage(); - cv::Mat image_data_rotated; - cv::Size input_size = image_data.size(); - string run_dir = GetRunDirectory(); - check_default_languages(ocr_fset, job.job_name, run_dir, job_status); - preprocess_image(job, image_data, ocr_fset); - Properties osd_detection_properties; - string tessdata_script_dir = ""; - set missing_languages; - int xLeftUpper = 0; - int yLeftUpper = 0; - int width = input_size.width; - int height = input_size.height; - int orientation_result = 0; - vector locations; + check_default_languages(ocr_fset, job.job_name, run_dir); + vector locations = process_image_job(job, ocr_fset, image_data, run_dir); - if (ocr_fset.psm == 0 || ocr_fset.enable_osd) { + for (auto &location : locations) { + image_reader.ReverseTransform(location); + } - OSBestResult os_best_result; - get_OSD(os_best_result, image_data, job, ocr_fset, osd_detection_properties, - tessdata_script_dir, missing_languages); + LOG4CXX_INFO(hw_logger_, + "[" + job.job_name + "] Processing complete. Found " + to_string(locations.size()) + " locations."); - // Rotate upper left coordinates based on OSD results. - if (ocr_fset.min_orientation_confidence <= os_best_result.oconfidence) { - orientation_result = os_best_result.orientation_id; - set_coordinates(xLeftUpper, yLeftUpper, width, height, input_size, os_best_result.orientation_id); - } + return locations; + } + catch (...) { + Utils::LogAndReThrowException(job, hw_logger_); + } +} - // When PSM is set to 0, there is no need to process any further. - if (ocr_fset.psm == 0) { - LOG4CXX_INFO(hw_logger_, - "[" + job.job_name + "] Processing complete. Found " + to_string(locations.size()) + - " tracks."); - osd_detection_properties["MISSING_LANGUAGE_MODELS"] = boost::algorithm::join(missing_languages, ", "); - MPFImageLocation osd_detection(xLeftUpper, yLeftUpper, width, height, -1, osd_detection_properties); - locations.push_back(osd_detection); - return locations; - } - } else { - // If OSD is not run, the image won't be rescaled yet. 
- rescale_image(job, image_data, ocr_fset); - } - osd_detection_properties["MISSING_LANGUAGE_MODELS"] = boost::algorithm::join(missing_languages, ", "); - set remaining_languages; - string first_pass_rotation, second_pass_rotation; - double min_ocr_conf = ocr_fset.rotate_and_detect_min_confidence; - int corrected_orientation; - int corrected_X, corrected_Y, corrected_width, corrected_height; +vector TesseractOCRTextDetection::process_image_job(const MPFJob &job, + OCR_filter_settings &ocr_fset, + cv::Mat &image_data, + const string &run_dir) { - if (ocr_fset.rotate_and_detect) { - cv::rotate(image_data, image_data_rotated, cv::ROTATE_180); - remaining_languages = generate_lang_set(ocr_fset.tesseract_lang); - double rotation_val = 0.0; - if (osd_detection_properties.count("ROTATION")) { - rotation_val = boost::lexical_cast(osd_detection_properties["ROTATION"]); - } - first_pass_rotation = to_string(rotation_val); - second_pass_rotation = to_string(180.0 + rotation_val); - corrected_orientation = (orientation_result + 2) % 4; - set_coordinates(corrected_X, corrected_Y, corrected_width, corrected_height, input_size, corrected_orientation); - } + LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] About to run tesseract"); + vector ocr_outputs; + cv::Mat image_data_rotated; + cv::Size input_size = image_data.size(); - // Run initial get_tesseract_detections. When autorotate is set, for any languages that fall below initial pass - // create a second round of extractions with a 180 degree rotation applied on top of the original setting. - // Second rotation only triggers if ROTATE_AND_DETECT is set. + preprocess_image(job, image_data, ocr_fset); - OCR_job_inputs ocr_job_inputs; - ocr_job_inputs.job_name = &job.job_name; - ocr_job_inputs.lang = &ocr_fset.tesseract_lang; - ocr_job_inputs.tessdata_script_dir = &tessdata_script_dir; - ocr_job_inputs.run_dir = &run_dir; - ocr_job_inputs.imi = &image_data; - ocr_job_inputs.ocr_fset = &ocr_fset; - ocr_job_inputs.process_pdf = false; - ocr_job_inputs.hw_logger_ = hw_logger_; - ocr_job_inputs.tess_api_map = &tess_api_map; + Properties osd_detection_properties; + string tessdata_script_dir = ""; + set missing_languages; + int xLeftUpper = 0; + int yLeftUpper = 0; + int width = input_size.width; + int height = input_size.height; + int orientation_result = 0; + vector locations; - Image_results image_results; - image_results.job_status = job_status; - get_tesseract_detections(ocr_job_inputs, image_results); - ocr_outputs = image_results.detections_by_lang; - job_status = image_results.job_status; - - vector all_results; - - for (const OCR_output &ocr_out: ocr_outputs) { - OCR_output final_out = ocr_out; - if (ocr_fset.rotate_and_detect) { - remaining_languages.erase(ocr_out.language); - final_out.two_pass_rotation = first_pass_rotation; - final_out.two_pass_correction = false; - - // Perform second pass OCR if min threshold is disabled (negative) or first pass confidence too low. - if (min_ocr_conf <= 0 || ocr_out.confidence < min_ocr_conf) { - // Perform second pass OCR and provide best result to output. 
- vector ocr_outputs_rotated; - ocr_fset.tesseract_lang = ocr_out.language; - ocr_job_inputs.lang = &ocr_fset.tesseract_lang; - ocr_job_inputs.imi = &image_data_rotated; - image_results.detections_by_lang.clear(); - image_results.job_status = job_status; - - get_tesseract_detections(ocr_job_inputs, image_results); - ocr_outputs_rotated = image_results.detections_by_lang; - job_status = image_results.job_status; - - OCR_output ocr_out_rotated = ocr_outputs_rotated.front(); - if (ocr_out_rotated.confidence > ocr_out.confidence) { - final_out = ocr_out_rotated; - final_out.two_pass_rotation = second_pass_rotation; - final_out.two_pass_correction = true; - } + if (ocr_fset.psm == 0 || ocr_fset.enable_osd) { + + OSBestResult os_best_result; + get_OSD(os_best_result, image_data, job, ocr_fset, osd_detection_properties, + tessdata_script_dir, missing_languages); + + // Rotate upper left coordinates based on OSD results. + if (ocr_fset.min_orientation_confidence <= os_best_result.oconfidence) { + orientation_result = os_best_result.orientation_id; + set_coordinates(xLeftUpper, yLeftUpper, width, height, input_size, os_best_result.orientation_id); + } + + // When PSM is set to 0, there is no need to process any further. + if (ocr_fset.psm == 0) { + LOG4CXX_INFO(hw_logger_, + "[" + job.job_name + "] Processing complete. Found " + to_string(locations.size()) + + " tracks."); + osd_detection_properties["MISSING_LANGUAGE_MODELS"] = boost::algorithm::join(missing_languages, ", "); + MPFImageLocation osd_detection(xLeftUpper, yLeftUpper, width, height, -1, osd_detection_properties); + locations.push_back(osd_detection); + return locations; + } + } else { + // If OSD is not run, the image won't be rescaled yet. + rescale_image(job, image_data, ocr_fset); + } + osd_detection_properties["MISSING_LANGUAGE_MODELS"] = boost::algorithm::join(missing_languages, ", "); + + set remaining_languages; + string first_pass_rotation, second_pass_rotation; + double min_ocr_conf = ocr_fset.rotate_and_detect_min_confidence; + int corrected_orientation; + int corrected_X, corrected_Y, corrected_width, corrected_height; + + if (ocr_fset.rotate_and_detect) { + cv::rotate(image_data, image_data_rotated, cv::ROTATE_180); + remaining_languages = generate_lang_set(ocr_fset.tesseract_lang); + double rotation_val = 0.0; + if (osd_detection_properties.count("ROTATION")) { + rotation_val = boost::lexical_cast(osd_detection_properties["ROTATION"]); + } + first_pass_rotation = to_string(rotation_val); + second_pass_rotation = to_string(180.0 + rotation_val); + corrected_orientation = (orientation_result + 2) % 4; + set_coordinates(corrected_X, corrected_Y, corrected_width, corrected_height, input_size, corrected_orientation); + } + + // Run initial get_tesseract_detections. When autorotate is set, for any languages that fall below initial pass + // create a second round of extractions with a 180 degree rotation applied on top of the original setting. + // Second rotation only triggers if ROTATE_AND_DETECT is set. 
+ + OCR_job_inputs ocr_job_inputs; + ocr_job_inputs.job_name = &job.job_name; + ocr_job_inputs.lang = &ocr_fset.tesseract_lang; + ocr_job_inputs.tessdata_script_dir = &tessdata_script_dir; + ocr_job_inputs.run_dir = &run_dir; + ocr_job_inputs.imi = &image_data; + ocr_job_inputs.ocr_fset = &ocr_fset; + ocr_job_inputs.process_pdf = false; + ocr_job_inputs.hw_logger_ = hw_logger_; + ocr_job_inputs.tess_api_map = &tess_api_map; + + Image_results image_results; + get_tesseract_detections(ocr_job_inputs, image_results); + ocr_outputs = image_results.detections_by_lang; + + vector all_results; + + for (const OCR_output &ocr_out: ocr_outputs) { + OCR_output final_out = ocr_out; + if (ocr_fset.rotate_and_detect) { + remaining_languages.erase(ocr_out.language); + final_out.two_pass_rotation = first_pass_rotation; + final_out.two_pass_correction = false; + + // Perform second pass OCR if min threshold is disabled (negative) or first pass confidence too low. + if (min_ocr_conf <= 0 || ocr_out.confidence < min_ocr_conf) { + // Perform second pass OCR and provide best result to output. + vector ocr_outputs_rotated; + ocr_fset.tesseract_lang = ocr_out.language; + ocr_job_inputs.lang = &ocr_fset.tesseract_lang; + ocr_job_inputs.imi = &image_data_rotated; + image_results.detections_by_lang.clear(); + + get_tesseract_detections(ocr_job_inputs, image_results); + ocr_outputs_rotated = image_results.detections_by_lang; + + OCR_output ocr_out_rotated = ocr_outputs_rotated.front(); + if (ocr_out_rotated.confidence > ocr_out.confidence) { + final_out = ocr_out_rotated; + final_out.two_pass_rotation = second_pass_rotation; + final_out.two_pass_correction = true; } } - all_results.push_back(final_out); } + all_results.push_back(final_out); + } - // If two stage OCR is enabled, run the second pass of OCR on any remaining languages where the first pass failed - // to generate an output. - for (const string &rem_lang: remaining_languages) { - // Perform second pass OCR and provide best result to output. - vector ocr_outputs_rotated; - ocr_fset.tesseract_lang = rem_lang; - ocr_job_inputs.lang = &ocr_fset.tesseract_lang; - ocr_job_inputs.imi = &image_data_rotated; - image_results.detections_by_lang.clear(); - image_results.job_status = job_status; + // If two stage OCR is enabled, run the second pass of OCR on any remaining languages where the first pass failed + // to generate an output. + for (const string &rem_lang: remaining_languages) { + // Perform second pass OCR and provide best result to output. + vector ocr_outputs_rotated; + ocr_fset.tesseract_lang = rem_lang; + ocr_job_inputs.lang = &ocr_fset.tesseract_lang; + ocr_job_inputs.imi = &image_data_rotated; + image_results.detections_by_lang.clear(); - get_tesseract_detections(ocr_job_inputs, image_results); - ocr_outputs_rotated = image_results.detections_by_lang; - job_status = image_results.job_status; + get_tesseract_detections(ocr_job_inputs, image_results); + ocr_outputs_rotated = image_results.detections_by_lang; - OCR_output ocr_out_rotated = ocr_outputs_rotated.front(); - ocr_out_rotated.two_pass_rotation = second_pass_rotation; - ocr_out_rotated.two_pass_correction = true; - all_results.push_back(ocr_out_rotated); - } + OCR_output ocr_out_rotated = ocr_outputs_rotated.front(); + ocr_out_rotated.two_pass_rotation = second_pass_rotation; + ocr_out_rotated.two_pass_correction = true; + all_results.push_back(ocr_out_rotated); + } - // If max_text_tracks is set, filter out to return only the top specified tracks. 
- if (ocr_fset.max_text_tracks > 0) { - sort(all_results.begin(), all_results.end(), greater()); - all_results.resize(ocr_fset.max_text_tracks); - } - - for (const OCR_output &final_out : all_results) { - MPFImageLocation image_location(xLeftUpper, yLeftUpper, width, height, final_out.confidence); - // Copy over OSD detection results into OCR output. - image_location.detection_properties = osd_detection_properties; - - // Mark two-pass OCR final selected rotation. - if (ocr_fset.rotate_and_detect) { - image_location.detection_properties["ROTATION"] = final_out.two_pass_rotation; - if (final_out.two_pass_correction) { - image_location.detection_properties["ROTATE_AND_DETECT_PASS"] = "180"; - // Also correct top left corner designation: - image_location.x_left_upper = corrected_X; - image_location.y_left_upper = corrected_Y; - image_location.width = corrected_width; - image_location.height = corrected_height; - } else { - image_location.detection_properties["ROTATE_AND_DETECT_PASS"] = "0"; - } + // If max_text_tracks is set, filter out to return only the top specified tracks. + if (ocr_fset.max_text_tracks > 0) { + sort(all_results.begin(), all_results.end(), greater()); + all_results.resize(ocr_fset.max_text_tracks); + } + for (const OCR_output &final_out : all_results) { + MPFImageLocation image_location(xLeftUpper, yLeftUpper, width, height, final_out.confidence); + // Copy over OSD detection results into OCR output. + image_location.detection_properties = osd_detection_properties; + + // Mark two-pass OCR final selected rotation. + if (ocr_fset.rotate_and_detect) { + image_location.detection_properties["ROTATION"] = final_out.two_pass_rotation; + if (final_out.two_pass_correction) { + image_location.detection_properties["ROTATE_AND_DETECT_PASS"] = "180"; + // Also correct top left corner designation: + image_location.x_left_upper = corrected_X; + image_location.y_left_upper = corrected_Y; + image_location.width = corrected_width; + image_location.height = corrected_height; + } else { + image_location.detection_properties["ROTATE_AND_DETECT_PASS"] = "0"; } - bool process_text = process_ocr_text(image_location.detection_properties, job, final_out, - ocr_fset); - if (process_text) { - locations.push_back(image_location); - } - } - for (auto &location : locations) { - image_reader.ReverseTransform(location); } - - LOG4CXX_INFO(hw_logger_, - "[" + job.job_name + "] Processing complete. Found " + to_string(locations.size()) + " tracks."); - return locations; - } - catch (...) 
{ - Utils::LogAndReThrowException(job, hw_logger_); + bool process_text = process_ocr_text(image_location.detection_properties, job, final_out, + ocr_fset); + if (process_text) { + locations.push_back(image_location); + } } -} + return locations; +} void TesseractOCRTextDetection::process_parallel_pdf_pages(PDF_page_inputs &page_inputs, PDF_page_results &page_results) { PDF_thread_variables thread_var[page_inputs.filelist.size()]; - std::future pdf_threads[page_inputs.filelist.size()]; + future pdf_threads[page_inputs.filelist.size()]; int index = 0; set active_threads; for (const string &filename : page_inputs.filelist) { - thread_var[index].page_thread_res.job_status = page_results.job_status; MPFImageJob c_job((*page_inputs.job).job_name, filename, (*page_inputs.job).job_properties, @@ -1628,18 +1699,18 @@ void TesseractOCRTextDetection::process_parallel_pdf_pages(PDF_page_inputs &page thread_var[index].ocr_input.process_pdf = true; thread_var[index].ocr_input.hw_logger_ = hw_logger_; thread_var[index].ocr_input.tess_api_map = &tess_api_map; - pdf_threads[index] = std::async(launch::async, + pdf_threads[index] = async(launch::async, &get_tesseract_detections, - std::ref(thread_var[index].ocr_input), - std::ref(thread_var[index].page_thread_res)); + ref(thread_var[index].ocr_input), + ref(thread_var[index].page_thread_res)); active_threads.insert(index); index ++; while (active_threads.size() >= page_inputs.ocr_fset.max_parallel_pdf_threads) { - std::this_thread::sleep_for(std::chrono::seconds(1)); + this_thread::sleep_for(chrono::seconds(1)); for (auto active_thread_iter = active_threads.begin(); active_thread_iter != active_threads.end(); /* Intentionally no ++ to support erasing. */) { int i = *active_thread_iter; - if (pdf_threads[i].wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { + if (pdf_threads[i].wait_for(chrono::milliseconds(0)) == future_status::ready) { // Will re-throw exception from thread. pdf_threads[i].get(); active_thread_iter = active_threads.erase(active_thread_iter); @@ -1658,7 +1729,6 @@ void TesseractOCRTextDetection::process_parallel_pdf_pages(PDF_page_inputs &page } for (index = 0; index < page_inputs.filelist.size(); index++ ) { - MPFImageJob c_job((*page_inputs.job).job_name, "", (*page_inputs.job).job_properties, (*page_inputs.job).media_properties); // If max_text_tracks is set, filter out to return only the top specified tracks. if (page_inputs.ocr_fset.max_text_tracks > 0) { @@ -1674,7 +1744,7 @@ void TesseractOCRTextDetection::process_parallel_pdf_pages(PDF_page_inputs &page // Copy over OSD results into OCR tracks. 
generic_track.detection_properties = thread_var[index].osd_track_results.detection_properties; - bool process_text = process_ocr_text(generic_track.detection_properties, c_job, ocr_out, + bool process_text = process_ocr_text(generic_track.detection_properties, *page_inputs.job, ocr_out, page_inputs.ocr_fset, index); if (process_text) { (*page_results.tracks).push_back(generic_track); @@ -1687,7 +1757,10 @@ void TesseractOCRTextDetection::process_serial_pdf_pages(PDF_page_inputs &page_i PDF_page_results &page_results) { int page_num = 0; for (const string &filename : page_inputs.filelist) { - MPFImageJob c_job((*page_inputs.job).job_name, filename, (*page_inputs.job).job_properties, (*page_inputs.job).media_properties); + MPFImageJob c_job((*page_inputs.job).job_name, + filename, + (*page_inputs.job).job_properties, + (*page_inputs.job).media_properties); MPFImageReader image_reader(c_job); cv::Mat image_data = image_reader.GetImage(); preprocess_image(c_job, image_data, page_inputs.ocr_fset); @@ -1717,7 +1790,6 @@ void TesseractOCRTextDetection::process_serial_pdf_pages(PDF_page_inputs &page_i rescale_image(c_job, image_data, page_inputs.ocr_fset); } - OCR_job_inputs ocr_job_inputs; ocr_job_inputs.job_name = &c_job.job_name; ocr_job_inputs.lang = &page_inputs.ocr_fset.tesseract_lang; @@ -1730,10 +1802,7 @@ void TesseractOCRTextDetection::process_serial_pdf_pages(PDF_page_inputs &page_i ocr_job_inputs.tess_api_map = &tess_api_map; Image_results image_results; - image_results.job_status = page_results.job_status; - get_tesseract_detections(ocr_job_inputs, image_results); - page_results.job_status = image_results.job_status; // If max_text_tracks is set, filter out to return only the top specified tracks. if (page_inputs.ocr_fset.max_text_tracks > 0) { @@ -1760,18 +1829,17 @@ void TesseractOCRTextDetection::process_serial_pdf_pages(PDF_page_inputs &page_i vector TesseractOCRTextDetection::GetDetections(const MPFGenericJob &job) { try { - LOG4CXX_DEBUG(hw_logger_, "[" + job.job_name + "] Processing \"" + job.data_uri + "\"."); + LOG4CXX_INFO(hw_logger_, "[" + job.job_name + "] Starting job."); PDF_page_inputs page_inputs; PDF_page_results page_results; - std::vector tracks; + vector tracks; page_results.tracks = &tracks; page_inputs.job = &job; load_settings(job, page_inputs.ocr_fset); - - page_results.job_status = MPF_DETECTION_SUCCESS; + load_image_preprocessing_settings(job, page_inputs.ocr_fset); vector job_names; boost::split(job_names, job.job_name, boost::is_any_of(":")); @@ -1783,7 +1851,7 @@ vector TesseractOCRTextDetection::GetDetections(const MPFGeneri page_inputs.run_dir = "."; } - check_default_languages(page_inputs.ocr_fset, job.job_name, page_inputs.run_dir, page_results.job_status); + check_default_languages(page_inputs.ocr_fset, job.job_name, page_inputs.run_dir); string plugin_path = page_inputs.run_dir + "/TesseractOCRTextDetection"; TempDirectory temp_im_directory(plugin_path + "/tmp-" + job_name + "-" + random_string(20)); @@ -1831,12 +1899,12 @@ vector TesseractOCRTextDetection::GetDetections(const MPFGeneri } -TessApiWrapper::TessApiWrapper(const std::string& data_path, const std::string& language, tesseract::OcrEngineMode oem) { +TessApiWrapper::TessApiWrapper(const string& data_path, const string& language, tesseract::OcrEngineMode oem) { int rc = tess_api_.Init(data_path.c_str(), language.c_str(), oem); if (rc != 0) { throw MPFDetectionException( MPF_DETECTION_NOT_INITIALIZED, - "Failed to initialize Tesseract! 
Error code: " + std::to_string(rc)); + "Failed to initialize Tesseract! Error code: " + to_string(rc)); } } @@ -1849,12 +1917,12 @@ void TessApiWrapper::SetImage(const cv::Mat &image) { static_cast(image.step)); } -std::string TessApiWrapper::GetUTF8Text() { - std::unique_ptr text{tess_api_.GetUTF8Text()}; +string TessApiWrapper::GetUTF8Text() { + unique_ptr text{tess_api_.GetUTF8Text()}; if (text == nullptr) { return ""; } - return std::string(text.get()); + return string(text.get()); } int TessApiWrapper::MeanTextConf() { diff --git a/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.h b/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.h index 3be9a4c6..33212afa 100755 --- a/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.h +++ b/cpp/TesseractOCRTextDetection/TesseractOCRTextDetection.h @@ -60,7 +60,7 @@ namespace MPF { class TessApiWrapper; - class TesseractOCRTextDetection : public MPFImageDetectionComponentAdapter { + class TesseractOCRTextDetection : public MPFDetectionComponent { public: bool Init() override; @@ -71,6 +71,10 @@ namespace MPF { std::vector GetDetections(const MPFGenericJob &job) override; + std::vector GetDetections(const MPFVideoJob &job) override; + + std::vector GetDetections(const MPFAudioJob &job) override; + std::string GetDetectionType() override; bool Supports(MPFDetectionDataType data_type) override; @@ -158,13 +162,11 @@ namespace MPF { struct Image_results{ std::vector detections_by_lang; - MPFDetectionError job_status; }; struct OCR_results { std::string text_result; std::string lang; - MPFDetectionError job_status; double confidence; }; @@ -180,7 +182,6 @@ namespace MPF { struct PDF_page_results { std::set all_missing_languages; - MPFDetectionError job_status; std::vector *tracks; }; @@ -201,7 +202,12 @@ namespace MPF { } }; - bool process_ocr_text(Properties &detection_properties, const MPFImageJob &job, const OCR_output &ocr_out, + std::vector process_image_job(const MPFJob &job, + TesseractOCRTextDetection::OCR_filter_settings &ocr_fset, + cv::Mat &image_data, + const std::string &run_dir); + + bool process_ocr_text(Properties &detection_properties, const MPFJob &job, const OCR_output &ocr_out, const TesseractOCRTextDetection::OCR_filter_settings &ocr_fset, int page_num = -1); @@ -222,8 +228,8 @@ namespace MPF { static void process_parallel_image_runs(OCR_job_inputs &inputs, Image_results &results); static void process_serial_image_runs(OCR_job_inputs &inputs, Image_results &results); - void preprocess_image(const MPFImageJob &job, cv::Mat &input_image, const OCR_filter_settings &ocr_fset); - void rescale_image(const MPFImageJob &job, cv::Mat &input_image, const OCR_filter_settings &ocr_fset); + void preprocess_image(const MPFJob &job, cv::Mat &input_image, const OCR_filter_settings &ocr_fset); + void rescale_image(const MPFJob &job, cv::Mat &input_image, const OCR_filter_settings &ocr_fset); static void process_tesseract_lang_model(OCR_job_inputs &input, OCR_results &result); @@ -231,20 +237,23 @@ namespace MPF { void set_read_config_parameters(); - void load_settings(const MPFJob &job, OCR_filter_settings &ocr_fset, const Text_type &text_type = Unknown); + void load_settings(const MPFJob &job, OCR_filter_settings &ocr_fset); + void load_image_preprocessing_settings(const MPFJob &job, + OCR_filter_settings &ocr_fset, + const Text_type &text_type = Unknown); void sharpen(cv::Mat &image, double weight); static std::string process_osd_lang(const std::string &script_type, const OCR_filter_settings &ocr_fset); - void 
get_OSD(OSBestResult &best_result, cv::Mat &imi, const MPFImageJob &job, + void get_OSD(OSBestResult &best_result, cv::Mat &imi, const MPFJob &job, OCR_filter_settings &ocr_fset, Properties &detection_properties, std::string &tessdata_script_dir, std::set<std::string> &missing_languages); bool get_OSD_rotation(OSResults *results, cv::Mat &imi_scaled, cv::Mat &imi_original, - int &rotation, const MPFImageJob &job, OCR_filter_settings &ocr_fset); + int &rotation, const MPFJob &job, OCR_filter_settings &ocr_fset); static std::string return_valid_tessdir(const std::string &job_name, const std::string &lang_str, @@ -265,8 +274,7 @@ namespace MPF { void check_default_languages(const OCR_filter_settings &ocr_fset, const std::string &job_name, - const std::string &run_dir, - MPFDetectionError &job_status); + const std::string &run_dir); }; // The primary reason this class exists is that tesseract::TessBaseAPI segfaults when copying or moving. diff --git a/cpp/TesseractOCRTextDetection/sample_tesseract_ocr_detector.cpp b/cpp/TesseractOCRTextDetection/sample_tesseract_ocr_detector.cpp index 1f32de55..d177bd51 100755 --- a/cpp/TesseractOCRTextDetection/sample_tesseract_ocr_detector.cpp +++ b/cpp/TesseractOCRTextDetection/sample_tesseract_ocr_detector.cpp @@ -45,10 +45,10 @@ using std::to_string; void print_usage(char *argv[]) { std::cout << "Usage: " << argv[0] << - " <-i | -g> [--osd] [--oem TESSERACT_OEM] <IMAGE_URI | GENERIC_URI> [TESSERACT_LANGUAGE]" << + " <-i | -v | -g> [--osd] [--oem TESSERACT_OEM] <IMAGE_URI | VIDEO_URI START_FRAME END_FRAME | GENERIC_URI> [TESSERACT_LANGUAGE]" << std::endl << std::endl; std::cout << "Notes: " << std::endl << std::endl; - std::cout << " -i | -g : Specifies whether to process an image (-i <IMAGE_URI>) or generic document (-g <GENERIC_URI>)." << + std::cout << " <-i | -v | -g> : Specifies whether to process an image (-i <IMAGE_URI>), video (-v <VIDEO_URI>), or generic document (-g <GENERIC_URI>)." << std::endl << std::endl; std::cout << " --osd : When provided, runs the job with automatic orientation and script detection (OSD). 
" << std::endl; @@ -102,8 +102,8 @@ bool check_options(const std::string &next_option, const int &argc, char *argv[ if (next_option == "--osd") { algorithm_properties["ENABLE_OSD_AUTOMATION"] = "true"; uri_index++; - } else if (next_option == "--oem" || argc - uri_index > 2) { - std::cout << "Updating OEM MODE " << argv[uri_index + 1]; + } else if (next_option == "--oem" && argc - uri_index > 2) { + std::cout << "Updating OEM MODE " << argv[uri_index + 1] << std::endl; algorithm_properties["TESSERACT_OEM"] = argv[uri_index + 1]; uri_index += 2; } else { @@ -131,18 +131,30 @@ int main(int argc, char *argv[]) { algorithm_properties["SHARPEN"] = "1.0"; algorithm_properties["ENABLE_OSD_AUTOMATION"] = "false"; - int uri_index = 2; + int uri_index = 2, video_params = 0, start_frame = 0, end_frame = 1; + std::string next_option = std::string(argv[uri_index]); if (check_options(next_option, argc, argv, algorithm_properties, uri_index)) { next_option = std::string(argv[uri_index]); check_options(next_option, argc, argv, algorithm_properties, uri_index); } - if (argc - uri_index == 1) { + if (media_option == "-v") { + video_params = 2; + if (argc - uri_index < 3) { + print_usage(argv); + return 0; + + } + start_frame = std::stoi(argv[uri_index+1]); + end_frame = std::stoi(argv[uri_index+2]); + } + + if (argc - uri_index - video_params == 1) { uri = argv[uri_index]; - } else if (argc - uri_index == 2) { + } else if (argc - uri_index - video_params == 2) { uri = argv[uri_index]; - algorithm_properties["TESSERACT_LANGUAGE"] = argv[uri_index + 1]; + algorithm_properties["TESSERACT_LANGUAGE"] = argv[uri_index + video_params + 1]; } else { print_usage(argv); return 0; @@ -176,6 +188,22 @@ int main(int argc, char *argv[]) { print_detection_properties(locations[i].detection_properties, locations[i].confidence); } } + else if (media_option == "-v") { + // Run uri as an image data file. + std::cout << "Running job on video data uri: " << uri << std::endl; + MPFVideoJob job(job_name, uri, start_frame, end_frame, algorithm_properties, media_properties); + int count = 0; + for (auto track: im.GetDetections(job)) { + std::cout << "Track number: " << count << std::endl; + std::map locations = track.frame_locations; + std::cout << "Number of image locations: " << locations.size() << std::endl << std::endl; + for (const auto &location: locations) { + std::cout << "Frame number: " << location.first << std::endl; + print_detection_properties(location.second.detection_properties, location.second.confidence); + } + count ++; + } + } else { print_usage(argv); } diff --git a/cpp/TesseractOCRTextDetection/test/data/NOTICE b/cpp/TesseractOCRTextDetection/test/data/NOTICE index a392aa1a..d3a35d85 100644 --- a/cpp/TesseractOCRTextDetection/test/data/NOTICE +++ b/cpp/TesseractOCRTextDetection/test/data/NOTICE @@ -82,6 +82,19 @@ Custom generated pdf for testing document text extraction. # test-backslash.png Custom generated image for testing escaped backslash tagging. +# test-video-detection.avi +Short clip of three separate image frames for testing video detection capability. +Contains public domain text from the following sources: + + https://en.wikipedia.org/wiki/Diazepam + (Japanese Translation) + Public Domain + + http://www.un.org/en/universal-declaration-human-rights/ + English text from the Universal + Declaration of Human Rights. + Public Domain + # text-demo.png Text extracted from open source project https://github.com/tesseract-ocr/tesseract. 
diff --git a/cpp/TesseractOCRTextDetection/test/data/test-video-detection.avi b/cpp/TesseractOCRTextDetection/test/data/test-video-detection.avi new file mode 100755 index 00000000..4ec7244b Binary files /dev/null and b/cpp/TesseractOCRTextDetection/test/data/test-video-detection.avi differ diff --git a/cpp/TesseractOCRTextDetection/test/test_tesseract_ocr_detection.cpp b/cpp/TesseractOCRTextDetection/test/test_tesseract_ocr_detection.cpp index 4f6bcbfe..de4c3c8a 100644 --- a/cpp/TesseractOCRTextDetection/test/test_tesseract_ocr_detection.cpp +++ b/cpp/TesseractOCRTextDetection/test/test_tesseract_ocr_detection.cpp @@ -100,6 +100,25 @@ MPFGenericJob createPDFJob(const std::string &uri, const std::map<std::string, std::string> &custom = {}) { +MPFVideoJob createVideoJob(const std::string &uri, const int &start, const int &end, + const std::map<std::string, std::string> &custom = {}) { + Properties algorithm_properties; + Properties media_properties; + std::string job_name("OCR_test"); + setAlgorithmProperties(algorithm_properties, custom); + MPFVideoJob job(job_name, uri, start, end, algorithm_properties, media_properties); + return job; +} + /** * Helper function for running given image job. Checks if job results is not empty. * @@ -136,6 +155,27 @@ void runDocumentDetection(const std::string &doc_path, TesseractOCRTextDetection ASSERT_FALSE(generic_tracks.empty()); } + +/** + * Helper function for running given video job. Checks if job results is not empty. + * + * @param vid_path - Path of given video. + * @param ocr - TesseractOCRTextDetection component for running given job. + * @param video_tracks - Output vector of video detection tracks for given job. + * @param start - Video start frame. + * @param end - Video end frame. + * @param custom - Mapping of input job properties. + */ +void runVideoDetection(const std::string &vid_path, TesseractOCRTextDetection &ocr, + std::vector<MPFVideoTrack> &video_tracks, + const int &start, const int &end, + const std::map<std::string, std::string> &custom = {}) { + MPFVideoJob job = createVideoJob(vid_path, start, end, custom); + video_tracks = ocr.GetDetections(job); + ASSERT_FALSE(video_tracks.empty()); +} + + /** * Helper function for checking if running given image job will return no results. * @@ -485,6 +525,40 @@ TEST(TESSERACTOCR, CustomModelTest) { ASSERT_TRUE(ocr.Close()); } +TEST(TESSERACTOCR, VideoProcessingTest) { + + // Ensure video processing works as expected. + + TesseractOCRTextDetection ocr; + ocr.SetRunDirectory("../plugin"); + std::vector<MPFVideoTrack> track_results; + std::vector<MPFImageLocation> results; + ASSERT_TRUE(ocr.Init()); + + std::map<std::string, std::string> custom_properties = {{"TESSERACT_LANGUAGE", "eng"}, + {"ENABLE_OSD_AUTOMATION","TRUE"}}; + + ASSERT_NO_FATAL_FAILURE(runVideoDetection("data/test-video-detection.avi", ocr, track_results, 0, 2, custom_properties)); + + for (auto track_result: track_results) { + for (auto result: track_result.frame_locations) { + results.push_back(result.second); + } + } + + assertInImage("data/test-video-detection.avi", "Testing Text Detection", results, "TEXT", 0); + assertInImage("data/test-video-detection.avi", "eng", results, "TEXT_LANGUAGE", 0); + + assertInImage("data/test-video-detection.avi", "Japanese", results, "OSD_PRIMARY_SCRIPT", 1); + assertInImage("data/test-video-detection.avi", "Japanese", results, "MISSING_LANGUAGE_MODELS", 1); + + assertInImage("data/test-video-detection.avi", "All human beings", results, "TEXT", 2); + assertInImage("data/test-video-detection.avi", "Latin", results, "TEXT_LANGUAGE", 2); + + + ASSERT_TRUE(ocr.Close()); +} + TEST(TESSERACTOCR, ImageProcessingTest) { // Ensure contrast and unstructured image processing settings are enabled. 
@@ -546,8 +620,6 @@ TEST(TESSERACTOCR, ImageProcessingTest) { ASSERT_TRUE(ocr.Close()); } - - TEST(TESSERACTOCR, ModelTest) { // Ensure user can specify custom model directory locations. diff --git a/cpp/TrtisDetection/CMakeLists.txt b/cpp/TrtisDetection/CMakeLists.txt index c4204ece..e9e8a1e1 100644 --- a/cpp/TrtisDetection/CMakeLists.txt +++ b/cpp/TrtisDetection/CMakeLists.txt @@ -40,13 +40,16 @@ set(PACKAGE_DIR ${CMAKE_CURRENT_BINARY_DIR}/plugin/${PROJECT_NAME}) message("Package in ${PACKAGE_DIR}") find_package(OpenCV 4.5.0 EXACT REQUIRED PATHS /opt/opencv-4.5.0 - COMPONENTS opencv_core) + COMPONENTS opencv_core) find_package(mpfComponentInterface REQUIRED) find_package(mpfDetectionComponentApi REQUIRED) find_package(mpfComponentUtils REQUIRED) find_package(Qt4 REQUIRED) -find_package(request REQUIRED) + +set(CMAKE_PREFIX_PATH /opt/triton) find_package(CURL REQUIRED) +find_package(TritonCommon PATH_SUFFIXES 64 REQUIRED) +find_package(TRITON REQUIRED) set(BUILD_SHARED_LIBS ON) # make AWS use shared linking find_package(AWSSDK REQUIRED COMPONENTS core s3) @@ -57,13 +60,26 @@ set(TRTIS_DETECTION_SOURCE_FILES S3FeatureStorage.cpp S3FeatureStorage.h S3StorageUtil.cpp S3StorageUtil.h base64.h uri.h) add_library(mpfTrtisDetection SHARED ${TRTIS_DETECTION_SOURCE_FILES}) -target_link_libraries(mpfTrtisDetection request mpfComponentInterface mpfDetectionComponentApi mpfComponentUtils - ${OpenCV_LIBS} ${PROTOBUF_LIBRARY} ${CURL_LIBRARIES} ${AWSSDK_LINK_LIBRARIES}) +target_link_libraries(mpfTrtisDetection + mpfComponentInterface + mpfDetectionComponentApi + mpfComponentUtils + TRITON::grpcclient_static + ${OpenCV_LIBS} + ${PROTOBUF_LIBRARY} + ${CURL_LIBRARIES} + ${AWSSDK_LINK_LIBRARIES}) +# add separate dynamic libcurl for aws +target_link_libraries(mpfTrtisDetection /usr/local/lib64/libcurl.so) + configure_mpf_component(TrtisDetection TARGETS mpfTrtisDetection) add_subdirectory(test) # Build sample executable -add_executable(sample_trtis_detector sample_trtis_detector.cpp) -target_link_libraries(sample_trtis_detector mpfTrtisDetection) +include_directories(/home/mpf/mpf-sdk-install/include /opt/triton/include) +add_executable(sample_trtis_detector + sample_trtis_detector.cpp) +target_link_libraries(sample_trtis_detector + mpfTrtisDetection)\ No newline at end of file diff --git a/cpp/TrtisDetection/Dockerfile b/cpp/TrtisDetection/Dockerfile index 6a1345bd..76efa190 100644 --- a/cpp/TrtisDetection/Dockerfile +++ b/cpp/TrtisDetection/Dockerfile @@ -37,104 +37,126 @@ FROM ${BUILD_REGISTRY}openmpf_cpp_component_build:${BUILD_TAG} as build_component RUN --mount=type=tmpfs,target=/tmp \ --mount=type=tmpfs,target=/var/cache/yum \ - yum -y --nogpgcheck install git openssl-devel curl-devel cuda-cufft-10-2 cuda-npp-10-2 \ - && yum clean all + yum update --assumeyes; \ + yum --assumeyes --nogpgcheck install \ + which file git \ + cuda-cufft-dev-10-2 cuda-npp-dev-10-2 \ + libtool openssl-devel rapidjson-devel \ + # gcc 4.9.4 dependencies + gmp-devel mpfr-devel libmpc-devel; \ + # add libb64 needed by Triton + rpm --nosignature -i http://www6.atomicorp.com/channels/atomic/centos/7/x86_64/RPMS/libb64-libs-1.2.1-2.1.el7.art.x86_64.rpm; \ + rpm --nosignature -i http://www6.atomicorp.com/channels/atomic/centos/7/x86_64/RPMS/libb64-devel-1.2.1-2.1.el7.art.x86_64.rpm; \ + yum clean all; \ + rm --recursive /var/cache/yum/*; +# install gcc 4.9.4 +RUN --mount=type=tmpfs,target=/tmp \ + cd /tmp; \ + curl https://ftp.gnu.org/gnu/gcc/gcc-4.9.4/gcc-4.9.4.tar.gz | tar --extract --gzip; \ + mkdir -p /tmp/gcc-4.9.4/build; \ 
+ cd /tmp/gcc-4.9.4/build; \ + ../configure --enable-languages=c,c++ --disable-multilib; \ + make -j$(nproc) && make install; \ + echo '/usr/local/lib64' > /etc/ld.so.conf.d/locallib64.conf; \ + ldconfig; \ + rm --recursive /tmp/gcc-4.9.4; + +ENV CC=/usr/local/bin/gcc +ENV CXX=/usr/local/bin/g++ + + +# install updated cmake 3.18, configure existing opencv +RUN curl --location 'https://github.com/Kitware/CMake/releases/download/v3.18.6/cmake-3.18.6-Linux-x86_64.tar.gz' \ + | tar --extract --gzip --directory=/usr --strip-components=1; \ + ln -sf /usr/bin/cmake /usr/bin/cmake3; \ + echo '/opt/opencv-4.5.0/lib64' > /etc/ld.so.conf.d/opencv.conf; \ + ldconfig; \ + ln --symbolic '/opt/opencv-4.5.0/include/opencv4/opencv2' /usr/local/include/opencv2; -# Install cares, protobuf, zlib, gRPC WORKDIR /tmp +# Install newer version of curl needed for AWS S3 RUN --mount=type=tmpfs,target=/tmp \ -git clone -b v1.20.0 --depth 1 --recurse-submodules https://github.com/grpc/grpc \ -&& cd grpc \ -&& mkdir -p /tmp/grpc/third_party/protobuf/cmake/build \ -&& cd /tmp/grpc/third_party/protobuf/cmake/build \ -&& cmake3 -DCMAKE_POSITION_INDEPENDENT_CODE=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release .. \ -&& make -j`nproc` \ -&& make install \ -&& cd /tmp/ \ -&& rm -rf /tmp/grpc/third_party/protobuf \ -&& mkdir -p /tmp/grpc/third_party/cares/cares/cmake/build \ -&& cd /tmp/grpc/third_party/cares/cares/cmake/build \ -&& cmake3 -DCMAKE_BUILD_TYPE=Release ../.. \ -&& make -j`nproc` \ -&& make install \ -&& cd /tmp/ \ -&& rm -rf /tmp/grpc/third_party/cares/cares \ -&& rm -rf /tmp/grpc/third_party/zlib \ -&& rm -rf /tmp/grpc/third_party/protobuf \ -# build and install gRPC -&& mkdir -p /tmp/grpc/cmake/build \ -&& cd /tmp/grpc/cmake/build \ -&& cmake3 -DgRPC_INSTALL=ON \ - -DgRPC_BUILD_TESTS=OFF \ - -DgRPC_PROTOBUF_PROVIDER=package \ - -DgRPC_ZLIB_PROVIDER=package \ - -DgRPC_CARES_PROVIDER=package \ - -DgRPC_SSL_PROVIDER=package \ - -DCMAKE_BUILD_TYPE=Release ../.. \ -&& make -j`nproc` \ -&& make install \ -&& cd /tmp/ \ -&& rm -rf /tmp/grpc \ -&& ldconfig + cd /tmp; \ + git clone -b curl-7_76_1 https://github.com/curl/curl.git; \ + cd /tmp/curl; \ + mkdir build && cd build && cmake3 -DCMAKE_INSTALL_PREFIX:PATH=/usr/local -DBUILD_SHARED_LIBS=ON .. && make install; \ + ln -s /usr/local/lib64/libcurl.so /usr/local/lib/libcurl.so; \ + mv /usr/local/lib64/libcurl.so /lib64/libcurl.so.7.76.1; \ + ln -s /lib64/libcurl.so.7.76.1 /usr/local/lib64/libcurl.so; \ + echo '/usr/local/lib64' > /etc/ld.so.conf.d/locallib64.conf; \ + ldconfig; \ + rm --recursive /tmp/curl; # Install AWS SDK for C++ so we can use S3 storage. -# Do this before installing newer curl. -RUN --mount=type=tmpfs,target=/tmp \ - mkdir /tmp/aws-sdk-cpp \ - && cd /tmp/aws-sdk-cpp \ - && curl --location 'https://github.com/aws/aws-sdk-cpp/archive/1.8.48.tar.gz' \ - | tar --extract --gzip \ - && cd aws-sdk-cpp-1.8.48 \ - && mkdir build \ - && cd build \ - && cmake3 -DCMAKE_BUILD_TYPE=Release \ - -DBUILD_ONLY="s3" \ - -DCURL_DIR=/usr/lib64/ \ - -DCMAKE_C_FLAGS="-Wno-unused-variable -Wno-unused-parameter" \ - -DCMAKE_CXX_FLAGS="-Wno-unused-variable -Wno-unused-parameter" .. \ - && make --jobs "$(nproc)" install \ - && rm --recursive /tmp/aws-sdk-cpp - -# Install curl -# Newer version needed for TRTIS. RUN --mount=type=tmpfs,target=/tmp \ -cd /tmp/ \ -&& git clone -b curl-7_67_0 --depth 1 https://github.com/curl/curl.git \ -&& cd curl \ -&& mkdir build && cd build \ -&& cmake3 -DCMAKE_INSTALL_PREFIX:PATH=/usr/local -DBUILD_SHARED_LIBS=ON .. 
\ -&& make --jobs "$(nproc)" install \ -&& cp /usr/local/lib64/libcurl.so /usr/local/lib/libcurl.so \ -&& ldconfig - -# Install TensorRT-Inference-Server client libs. + mkdir /tmp/aws-sdk-cpp; \ + cd /tmp/aws-sdk-cpp; \ + curl --location 'https://github.com/aws/aws-sdk-cpp/archive/1.8.179.tar.gz' \ + | tar --extract --gzip; \ + cd aws-sdk-cpp-1.8.179; \ + mkdir build; \ + cd build; \ + cmake3 -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_ONLY="s3" \ + -DBUILD_SHARED_LIBS=ON \ + -DENABLE_TESTING=OFF \ + -DCURL_DIR=/usr/local/lib64/ ..; \ + make --jobs "$(nproc)" install; \ + rm --recursive /tmp/aws-sdk-cpp; + +# Triton Client build +ARG TRITON_REPO_TAG=r21.04 RUN --mount=type=tmpfs,target=/tmp \ -cd /tmp/ \ -&& git clone -b v1.7.0 --depth 1 https://github.com/NVIDIA/tensorrt-inference-server.git tensorrt-inference-server \ -# Disable building TRTIS python client. -&& sed -i '/add_subdirectory(\.\.\/\.\.\/src\/clients\/python src\/clients\/python)/d' /tmp/tensorrt-inference-server/build/trtis-clients/CMakeLists.txt \ -# Set version number. -&& sed -i 's/project (trtis-clients)/project(trtis-clients VERSION "0.0.0")/g' /tmp/tensorrt-inference-server/build/trtis-clients/CMakeLists.txt \ -&& mkdir -p /tmp/tensorrt-inference-server/build/trtis-clients/build \ -&& cd /tmp/tensorrt-inference-server/build/trtis-clients/build \ -&& cmake3 -DCMAKE_INSTALL_PREFIX=/root/trtis -DCMAKE_PREFIX_PATH=/usr/lib64/ -DCURL_DIR=/usr/local/ .. \ -&& make -j`nproc` \ -&& make install \ -&& cp /tmp/tensorrt-inference-server/src/core/constants.h /root/trtis/include/constants.h \ -&& chmod 644 /root/trtis/include/constants.h \ -&& echo '/root/trtis/lib/' > /etc/ld.so.conf.d/trtis.conf \ -&& echo '/usr/local/lib' > /etc/ld.so.conf.d/locallib.conf \ -&& ldconfig \ -# Fix bad header include references to 'src'. 
-&& find /root/trtis/include/ -type f -exec sed -i 's/#include "src\/clients\/c++\//#include "/g' {} \; \ -&& find /root/trtis/include/ -type f -exec sed -i 's/#include "src\/core\//#include "/g' {} \; \ -&& mkdir -p /root/trtis/cmake/ \ -&& echo 'add_library(request SHARED IMPORTED) # or STATIC instead of SHARED' > /root/trtis/cmake/request-config.cmake \ -&& echo 'set_target_properties(request PROPERTIES' >> /root/trtis/cmake/request-config.cmake \ -&& echo ' IMPORTED_LOCATION "/root/trtis/lib/librequest.so"' >> /root/trtis/cmake/request-config.cmake \ -&& echo ' INTERFACE_INCLUDE_DIRECTORIES "/root/trtis/include")' >> /root/trtis/cmake/request-config.cmake - + mkdir -p /tmp/triton; \ + cd /tmp/triton; \ + git clone -b ${TRITON_REPO_TAG} https://github.com/triton-inference-server/server.git; \ + # build the client library and examples + mkdir -p /tmp/triton/server/builddir; \ + cd /tmp/triton/server/builddir; \ + # remove python client source so it doesn't try to build it + sed -i '/add_subdirectory(\.\.\/\.\.\/src\/clients\/python src\/clients\/python)/d' /tmp/triton/server/build/client/CMakeLists.txt; \ + # fix c++14 idiom for clearing queue + sed -i 's/data_buffers_ = {};/data_buffers_ = std::queue>();/g' /tmp/triton/server/src/clients/c++/library/http_client.cc; \ + cmake -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX:PATH=/opt/triton/ \ + -DTRITON_COMMON_REPO_TAG:STRING=${TRITON_REPO_TAG} \ + -DTRITON_CORE_REPO_TAG:STRING=${TRITON_REPO_TAG} \ + -DTRITON_CLIENT_SKIP_EXAMPLES=ON \ + -DTRITON_ENABLE_PYTHON=OFF \ + -DTRITON_ENABLE_METRICS=OFF \ + -DTRITON_ENABLE_METRICS_GPU=OFF \ + -DTRITON_ENABLE_LOGGING=OFF \ + -DTRITON_ENABLE_STATS=OFF \ + -DTRITON_ENABLE_GPU=OFF \ + -DTRITON_ENABLE_GRPC=ON \ + -DTRITON_ENABLE_HTTP=ON \ + -DTRITON_CURL_WITHOUT_CONFIG=ON \ + ../build; \ + # symlink curl's lib64 to lib so we can find curl.a + mkdir -p /tmp/triton/server/builddir/curl/install; \ + ln -s /tmp/triton/server/builddir/curl/install/lib64 /tmp/triton/server/builddir/curl/install/lib; \ + make -j "$(nproc)" client; \ + # move missing libraries to install in /opt/triton + rsync -aL /tmp/triton/server/builddir/curl/install/ /opt/triton/; \ + rsync -aL /tmp/triton/server/builddir/protobuf/bin/ /opt/triton/bin/; \ + rsync -aL /tmp/triton/server/builddir/protobuf/include/ /opt/triton/include/; \ + rsync -aL /tmp/triton/server/builddir/protobuf/lib64/ /opt/triton/lib64/; \ + rsync -aL /tmp/triton/server/builddir/grpc/bin/ /opt/triton/bin/; \ + rsync -aL /tmp/triton/server/builddir/grpc/include/ /opt/triton/include/; \ + rsync -aL /tmp/triton/server/builddir/grpc/lib/ /opt/triton/lib/; \ + rsync -aL /tmp/triton/server/builddir/grpc/lib64/ /opt/triton/lib64; \ + rsync -aL /tmp/triton/server/builddir/grpc/share/ /opt/triton/share/; \ + rsync -aL /tmp/triton/server/builddir/c-ares/bin/ /opt/triton/bin/; \ + rsync -aL /tmp/triton/server/builddir/c-ares/include/ /opt/triton/include/c-ares/; \ + rsync -aL /tmp/triton/server/builddir/c-ares/lib/ /opt/triton/lib; \ + echo '/opt/triton/lib' > /etc/ld.so.conf.d/triton.conf; \ + ldconfig; \ + rm --recursive /tmp/triton; + +# set Triton version corresponding to r21.04 +ENV TRITON_VERSION=2.9.0 ENV request_DIR=/root/trtis WORKDIR /home/mpf/component_src @@ -146,7 +168,7 @@ COPY . . 
RUN build-component.sh -FROM nvcr.io/nvidia/tensorrtserver:19.04-py3 as openmpf_trtis_test_server +FROM nvcr.io/nvidia/tritonserver:21.04-py3 as openmpf_trtis_test_server COPY --from=build_component /usr/lib64/libssl3.so /usr/lib64/libnspr4.so /usr/lib64/libnss3.so \ /usr/lib64/libgtest*.so.0.0.0 /home/mpf/mpf-sdk-install/lib/ /usr/local/lib/ @@ -175,7 +197,7 @@ COPY --from=openmpf_trtis_test_server /Testing_MPF_FILES/libmpfTrtisDetection.so RUN cp $PLUGINS_DIR/TrtisDetection/lib/librequest.so /usr/local/lib/ && ldconfig LABEL org.label-schema.license="Apache 2.0" \ - org.label-schema.name="OpenMPF TRTIS Client Detection" \ + org.label-schema.name="OpenMPF TRITON Client Detection" \ org.label-schema.schema-version="1.0" \ org.label-schema.url="https://openmpf.github.io" \ org.label-schema.vcs-url="https://github.com/openmpf/openmpf-components" \ diff --git a/cpp/TrtisDetection/TrtisDetection.cpp b/cpp/TrtisDetection/TrtisDetection.cpp index d7c02e51..0739d4bf 100644 --- a/cpp/TrtisDetection/TrtisDetection.cpp +++ b/cpp/TrtisDetection/TrtisDetection.cpp @@ -114,8 +114,22 @@ T get(const Properties &p, const string &k, const T def) { return DetectionComponentUtils::GetProperty(p, k, def); } + /** **************************************************************************** -* Parse TRTIS setting out of MPFJob +* get vector of raw pointers from vector of unique pointers +***************************************************************************** */ +template <typename T> +vector<T*> getRaw(const vector<unique_ptr<T>> &v){ + vector<T*> raw; + for(const auto& i : v){ + raw.push_back(i.get()); + } + return raw; +} + + +/** **************************************************************************** +* Parse TRTIS settings out of MPFJob ***************************************************************************** */ TrtisJobConfig::TrtisJobConfig(const MPFJob &job, const log4cxx::LoggerPtr &log) : @@ -133,6 +147,9 @@ TrtisJobConfig::TrtisJobConfig(const MPFJob &job, maxInferConcurrency = get(jpr, "MAX_INFER_CONCURRENCY", 5); LOG4CXX_TRACE(log, "MAX_INFER_CONCURRENCY: " << maxInferConcurrency); + + clientTimeout = get(jpr, "INFER_TIMEOUT_US",0); + LOG4CXX_TRACE(log, "INFER_TIMEOUT_US: " << clientTimeout); } /** **************************************************************************** @@ -157,7 +174,9 @@ TrtisIpIrv2CocoJobConfig::TrtisIpIrv2CocoJobConfig(const MPFJob &job, : TrtisJobConfig(job, log), image_x_max(image_width - 1), image_y_max(image_height - 1), - userBBoxNorm({0.0f, 0.0f, 1.0f, 1.0f}) { + userBBoxNorm({0.0f, 0.0f, 1.0f, 1.0f}), + userBBoxNormShape({4}) +{ const Properties jpr = job.job_properties; userFeatEnabled = get(jpr, "USER_FEATURE_ENABLE", false); frameFeatEnabled = get(jpr, "FRAME_FEATURE_ENABLE", true); @@ -258,7 +277,7 @@ void TrtisDetection::_readClassNames(const string &model, * * \returns A continuous RGB data vector ready for inference server ***************************************************************************** */ -BytVec TrtisDetection::_cvRGBBytes(const cv::Mat &img, LngVec &shape) { +vector<uint8_t> TrtisDetection::_cvRGBBytes(const cv::Mat &img, vector<int64_t> &shape) { cv::Mat rgbImg; if (img.channels() == 3) { cv::cvtColor(img, rgbImg, cv::COLOR_BGR2RGB); @@ -279,7 +298,7 @@ BytVec TrtisDetection::_cvRGBBytes(const cv::Mat &img, LngVec &shape) { } // return continuous chunk of image data in a byte vector - BytVec data; + vector<uint8_t> data; size_t img_byte_size = rgbImg.total() * rgbImg.elemSize(); data.resize(img_byte_size); @@ -337,137 +356,73 @@ cv::Mat TrtisDetection::_cvResize(const cv::Mat &img, /** 
**************************************************************************** * Scale image colorspace and dimensions and prep for inference server * -* \param cfg configuration settings -* \param img OpenCV image to prep for inferencing -* \param ctx shared pointer to an inference context -* \param[out] shape shape of the imgDat tensor -* \param[out] imgDat image tensor data +* \param cfg configuration settings +* \param img OpenCV image to prep for inferencing +* \param[out] shape tensor shape vector output +* \param[out] imgDat tensor byte data output +* +* \returns vector of prepared inputs for inferencing * -* \note shape and imgDat need to persist till inference call ***************************************************************************** */ -void TrtisDetection::_ip_irv2_coco_prepImageData( - const TrtisIpIrv2CocoJobConfig &cfg, - const cv::Mat &img, - const sPtrInferCtx &ctx, - LngVec &shape, - BytVec &imgDat) { +vector<unique_ptr<nic::InferInput>> +TrtisDetection::_ip_irv2_coco_prepInputData(const TrtisIpIrv2CocoJobConfig &cfg, + const cv::Mat &img, + vector<int64_t> &imgShape, + vector<uint8_t> &imgDat) { + double scaleFactor = 1.0; LOG4CXX_TRACE(_log, "Preparing image data for inferencing"); if (cfg.clientScaleEnabled) { - imgDat = _cvRGBBytes(_cvResize(img, scaleFactor, 1024, 600), shape); + imgDat = _cvRGBBytes(_cvResize(img, scaleFactor, 1024, 600), imgShape); LOG4CXX_TRACE(_log, "using client side image scaling"); } else { - imgDat = _cvRGBBytes(img, shape); + imgDat = _cvRGBBytes(img, imgShape); LOG4CXX_TRACE(_log, "using TRTIS model's image scaling"); } - // Initialize the inputs with the data. - sPtrInferCtxInp inImgDat, inBBox; - NI_CHECK_OK(ctx->GetInput("image_input", &inImgDat), - "unable to get image_input"); - NI_CHECK_OK(ctx->GetInput("bbox_input", &inBBox), - "unable to get bbox_input"); - NI_CHECK_OK(inImgDat->Reset(), - "unable to reset image_input"); - NI_CHECK_OK(inBBox->Reset(), - "unable to reset bbox_input"); - NI_CHECK_OK(inImgDat->SetShape(shape), - "failed setting image_input shape"); - NI_CHECK_OK(inImgDat->SetRaw(imgDat), - "failed setting image_input"); - NI_CHECK_OK(inBBox->SetRaw((uint8_t * )(&(cfg.userBBoxNorm[0])), - cfg.userBBoxNorm.size() * sizeof(float)), - "failed setting bbox_input"); - LOG4CXX_TRACE(_log, "Prepped data for inferencing"); -} + vector<unique_ptr<nic::InferInput>> inferInputs; + nic::InferInput *tmp; -/** **************************************************************************** -* Create an inference context for a model. 
-* -* \param cfg job configuration settings containing TRTIS server info -* -* \returns shared pointer to an inferencing context -* to use for inferencing requests -***************************************************************************** */ -sPtrInferCtx TrtisDetection::_niGetInferContext(const TrtisJobConfig &cfg, int ctxId) { - uPtrInferCtx ctx; - NI_CHECK_OK(nic::InferGrpcContext::Create(&ctx, ctxId, cfg.trtis_server, cfg.model_name, cfg.model_version), - "unable to create TRTIS inference context for \"" + cfg.trtis_server + "\""); - - // Configure context for 'batch_size'=1 and return all outputs - uPtrInferCtxOpt options; - NI_CHECK_OK(nic::InferContext::Options::Create(&options), - "failed initializing TRTIS inference options"); - options->SetBatchSize(1); - for (const auto &output : ctx->Outputs()) { - options->AddRawResult(output); - } - NI_CHECK_OK(ctx->SetRunOptions(*options), - "failed initializing TRTIS batch size and outputs"); + NI_CHECK_OK(nic::InferInput::Create(&tmp,"image_input",imgShape,"UINT8"), + "unable to create 'image_input'"); + inferInputs.push_back(unique_ptr<nic::InferInput>(tmp)); + NI_CHECK_OK(inferInputs.back()->AppendRaw(imgDat), + "unable to set data for 'image_input'"); - LOG4CXX_TRACE(_log, "Created context[" << ctx->CorrelationId() << "]"); + NI_CHECK_OK(nic::InferInput::Create(&tmp,"bbox_input",cfg.userBBoxNormShape,"FP32"), + "unable to create 'bbox_input'"); + inferInputs.push_back(unique_ptr<nic::InferInput>(tmp)); + NI_CHECK_OK(inferInputs.back()->AppendRaw((uint8_t*)(&(cfg.userBBoxNorm[0])), + cfg.userBBoxNorm.size()*sizeof(float)), + "unable to set data for 'bbox_input'"); - return move(ctx); } /** **************************************************************************** -* Create inference contexts for a model. +* Create a pool of inference clients for a model. 
* -* \param cfg job configuration settings containing TRTIS server info +* \param cfg job configuration settings containing TRITON server info * -* \returns map of context ids to shared pointers to inferencing contexts -***************************************************************************** */ -unordered_map<int, sPtrInferCtx> TrtisDetection::_niGetInferContexts(const TrtisJobConfig &cfg) { - unordered_map<int, sPtrInferCtx> ctxMap; - - nic::Error err; - for (int i = 0; i < cfg.maxInferConcurrency; i++) { - ctxMap[i] = move(_niGetInferContext(cfg, i)); - } - - return ctxMap; -} - -/** **************************************************************************** -* convert ni::DataType enum to string for logging -* -* \param dt NVIDIA DataType enum +* \returns pointers to inferencing clients to use for inferencing requests * -* \returns string descriptor of enum value ***************************************************************************** */ -string TrtisDetection::_niType2Str(ni::DataType dt) { - switch (dt) { - case ni::TYPE_INVALID: - return "INVALID"; - case ni::TYPE_BOOL: - return "BOOL"; - case ni::TYPE_UINT8: - return "UINT8"; - case ni::TYPE_UINT16: - return "UINT16"; - case ni::TYPE_UINT32: - return "UINT32"; - case ni::TYPE_UINT64: - return "UINT64"; - case ni::TYPE_INT8: - return "INT8"; - case ni::TYPE_INT16: - return "INT16"; - case ni::TYPE_INT32: - return "INT32"; - case ni::TYPE_INT64: - return "INT64"; - case ni::TYPE_FP16: - return "FP16"; - case ni::TYPE_FP32: - return "FP32"; - case ni::TYPE_FP64: - return "FP64"; - case ni::TYPE_STRING: - return "STRING"; - default: - return "UNKNOWN"; +unordered_map<int, unique_ptr<nic::InferenceServerGrpcClient>> +TrtisDetection::_niGetInferenceClients(const TrtisJobConfig &cfg) { + + unordered_map<int, unique_ptr<nic::InferenceServerGrpcClient>> clients; + string server_url = cfg.trtis_server; + bool verbose = false; // keep client logging silent + bool use_ssl = false; // skip encryption overhead + auto ssl_options = nic::SslOptions(); + + for(int i=0; i < cfg.maxInferConcurrency; i++){ + unique_ptr<nic::InferenceServerGrpcClient> client; + NI_CHECK_OK(nic::InferenceServerGrpcClient::Create(&client, server_url, verbose, use_ssl, ssl_options), + "unable to create TRTIS inference client for \"" + server_url + "\""); + clients[i] = move(client); } + return clients; } /** **************************************************************************** @@ -482,19 +437,18 @@ string TrtisDetection::_niType2Str(ni::DataType dt) { ***************************************************************************** */ cv::Mat TrtisDetection::_niResult2CVMat(const int batch_idx, const string &name, - StrUPtrInferCtxResMap &results) { + const unique_ptr<nic::InferResult> &res) { // get raw data pointer and size const uint8_t *ptrRaw; size_t cntRaw; - uPtrInferCtxRes *res = &results.at(name); - NI_CHECK_OK((*res)->GetRaw(batch_idx, &ptrRaw, &cntRaw), - "Failed to get inference server result raw data"); + NI_CHECK_OK(res->RawData(name, &ptrRaw, &cntRaw), - "Failed to get inference server result raw data for '" + name +"'"); // get raw data shape - LngVec shape; - NI_CHECK_OK((*res)->GetRawShape(&shape), - "Failed to get inference server result shape"); + vector<int64_t> shape; + NI_CHECK_OK(res->Shape(name, &shape), + "Failed to get inference server result shape for '" + name +"'"); size_t ndim = shape.size(); if (ndim < 2) { // force matrix for vector with single col?! 
ndim = 2; @@ -502,7 +456,7 @@ cv::Mat TrtisDetection::_niResult2CVMat(const int batch_idx, } // calculate num elements from shape - IntVec iShape; + vector<int> iShape; int64 numElementsFromShape = 1; for (const auto &d: shape) { numElementsFromShape *= d; @@ -512,53 +466,42 @@ cv::Mat TrtisDetection::_niResult2CVMat(const int batch_idx, // determine opencv type and calculate num elements from raw size and data type size_t cvType; size_t sizeofEl; - ni::DataType niType = (*res)->GetOutput()->DType(); - switch (niType) { - case ni::TYPE_UINT8: - cvType = CV_8UC(ndim - 1); - sizeofEl = sizeof(uint8_t); - break; - case ni::TYPE_UINT16: - cvType = CV_16UC(ndim - 1); - sizeofEl = sizeof(uint16_t); - break; - case ni::TYPE_INT8: - cvType = CV_8SC(ndim - 1); - sizeofEl = sizeof(int8_t); - break; - case ni::TYPE_INT16: - cvType = CV_16SC(ndim - 1); - sizeofEl = sizeof(int16_t); - break; - case ni::TYPE_INT32: - cvType = CV_32SC(ndim - 1); - sizeofEl = sizeof(int32_t); - break; - case ni::TYPE_FP32: - cvType = CV_32FC(ndim - 1); - sizeofEl = sizeof(float); - break; - case ni::TYPE_FP64: - cvType = CV_64FC(ndim - 1); - sizeofEl = sizeof(double); - break; - // OpenCV does not support these types ?! - case ni::TYPE_UINT32: //cvType = CV_32UC(ndim-1); sizeofEl=sizeof(uint32_t); break; - case ni::TYPE_UINT64: //cvType = CV_64UC(ndim-1); sizeofEl=sizeof(uint64_t); break; - case ni::TYPE_INT64: //cvType = CV_64SC(ndim-1); sizeofEl=sizeof(int64_t); break; - case ni::TYPE_FP16: //cvType = CV_16FC(ndim-1); sizeofEl=sizeof(float16_t); break; - case ni::TYPE_BOOL: - case ni::TYPE_STRING: - case ni::TYPE_INVALID: - default: THROW_TRTISEXCEPTION(MPF_DETECTION_FAILED, - "Unsupported data_type " + _niType2Str(niType) + " in cv:Mat conversion"); + string datType; + NI_CHECK_OK(res->Datatype(name, &datType), + "Failed to get inference server result data type for '" + name +"'"); + if(datType == "FP32"){ + cvType = CV_32FC(ndim - 1); + sizeofEl = sizeof(float); + }else if(datType == "UINT8"){ + cvType = CV_8UC(ndim - 1); + sizeofEl = sizeof(uint8_t); + }else if(datType == "UINT16"){ + cvType = CV_16UC(ndim - 1); + sizeofEl = sizeof(uint16_t); + }else if(datType == "INT8"){ + cvType = CV_8SC(ndim - 1); + sizeofEl = sizeof(int8_t); + }else if(datType == "INT16"){ + cvType = CV_16SC(ndim - 1); + sizeofEl = sizeof(int16_t); + }else if(datType == "INT32"){ + cvType = CV_32SC(ndim - 1); + sizeofEl = sizeof(int32_t); + }else if(datType == "FP64"){ + cvType = CV_64FC(ndim - 1); + sizeofEl = sizeof(double); + }else{ // OpenCV does not support these types + // UINT32, UINT64, INT64, FP16, BOOL, BYTES: + THROW_TRTISEXCEPTION(MPF_DETECTION_FAILED, + "Unsupported data_type '" + datType + + "' in cv::Mat conversion"); } if (cntRaw / sizeofEl == numElementsFromShape) { return cv::Mat(ndim, iShape.data(), cvType, (void *) ptrRaw); } else { stringstream ss("Shape "); - ss << shape << " and data-type " << _niType2Str(niType) << "are inconsistent with buffer size " << cntRaw; + ss << shape << " and data-type '" << datType << "' are inconsistent with buffer size " << cntRaw; THROW_TRTISEXCEPTION(MPF_DETECTION_FAILED, ss.str()); } } @@ -614,8 +557,8 @@ bool TrtisDetection::Init() { ***************************************************************************** */ void TrtisDetection::_ip_irv2_coco_getDetections( const TrtisIpIrv2CocoJobConfig &cfg, - StrUPtrInferCtxResMap &res, - MPFImageLocationVec &locations) { + const unique_ptr<nic::InferResult> &res, + vector<MPFImageLocation> &locations) { if (cfg.frameFeatEnabled) { LOG4CXX_TRACE(_log, "processing global feature"); @@ 
-795,7 +738,7 @@ void TrtisDetection::_ip_irv2_coco_tracker( const TrtisIpIrv2CocoJobConfig &cfg, MPFImageLocation &loc, const int frameIdx, - MPFVideoTrackVec &tracks) { + vector &tracks) { MPFVideoTrack *bestTrackPtr = nullptr; float minFeatureGap = FLT_MAX; @@ -858,6 +801,10 @@ void TrtisDetection::_ip_irv2_coco_tracker( * * \returns Tracks collection to which detections will be added ***************************************************************************** */ +/*vector TrtisDetection::GetDetections(const MPFVideoJob &job){ + return vector {}; +} +/**/ vector TrtisDetection::GetDetections(const MPFVideoJob &job) { try { LOG4CXX_INFO(_log, "[" << job.job_name << "] Starting job"); @@ -885,71 +832,79 @@ vector TrtisDetection::GetDetections(const MPFVideoJob &job) { // frames per milli-sec if available double fp_ms = get(job.media_properties, "FPS", 0.0) / 1000.0; - unordered_map ctxMap = _niGetInferContexts(cfg); - size_t initialCtxPoolSize = ctxMap.size(); - LOG4CXX_TRACE(_log, "Retrieved inferencing context pool of size " << initialCtxPoolSize << " for model '" - << cfg.model_name << "' from server " - << cfg.trtis_server); + unordered_map> clients = _niGetInferenceClients(cfg); + int initialClientPoolSize = clients.size(); + LOG4CXX_TRACE(_log, "Retrieved inferencing client pool of size " << initialClientPoolSize << " for model '" + << cfg.model_name << "' from server " + << cfg.trtis_server); - unordered_set freeCtxPool; - for (const auto &pair : ctxMap) { - freeCtxPool.insert(pair.first); + unordered_set freeClients; + for (const auto &pair : clients) { + freeClients.insert(pair.first); } - mutex freeCtxMtx, nextRxFrameMtx, tracksMtx, errorMtx; + mutex freeClientsMtx, nextRxFrameMtx, tracksMtx, errorMtx; condition_variable freeCtxCv, nextRxFrameCv; exception_ptr eptr; // first exception thrown try { - LOG4CXX_TRACE(_log, "Main thread_id:" << this_thread::get_id()); + LOG4CXX_TRACE(_log, "Main thread_id:" << hex << this_thread::get_id()); int frameIdx = 0; int nextRxFrameIdx = 0; + nic::InferOptions inferOptions(cfg.model_name); + if(cfg.model_version > 0){ + inferOptions.model_version_ = to_string(cfg.model_version); + } + inferOptions.client_timeout_ = cfg.clientTimeout; + do { // Wait for an available inference context. 
LOG4CXX_TRACE(_log, "requesting inference from TRTIS server for frame[" << frameIdx << "]"); - int ctxId; + int clientId; { - unique_lock lk(freeCtxMtx); - if (freeCtxPool.empty()) { - LOG4CXX_TRACE(_log, "wait for an infer context to become available"); - freeCtxCv.wait(lk, [&freeCtxPool] { return !freeCtxPool.empty(); }); + unique_lock lk(freeClientsMtx); + if (freeClients.empty()) { + LOG4CXX_TRACE(_log, "wait for an infer client to become available"); + freeCtxCv.wait(lk, [&freeClients] { return !freeClients.empty(); }); } { lock_guard lk(errorMtx); if (eptr) { + LOG4CXX_ERROR(_log,"Error: an exception occurred"); break; // stop processing frames } } - auto it = freeCtxPool.begin(); - ctxId = *it; - freeCtxPool.erase(it); - LOG4CXX_TRACE(_log, "removing context[" << ctxId << "] from pool"); + auto it = freeClients.begin(); + clientId = *it; + freeClients.erase(it); + LOG4CXX_TRACE(_log, "using request client[" << clientId << "]"); } - sPtrInferCtx &ctx = ctxMap[ctxId]; + unique_ptr &client = clients[clientId]; - LngVec shape; - BytVec imgDat; - _ip_irv2_coco_prepImageData(cfg, frame, ctx, shape, imgDat); + vector shape; + vector frameDat; + auto inferInputs = _ip_irv2_coco_prepInputData(cfg, frame, shape, frameDat); LOG4CXX_TRACE(_log, "Loaded data into inference context"); LOG4CXX_DEBUG(_log, "frame[" << frameIdx << "] sending"); // Send inference request to the inference server. NI_CHECK_OK( - ctx->AsyncRun([frameIdx, &job, &cfg, &freeCtxCv, &nextRxFrameCv, - &nextRxFrameMtx, &tracksMtx, &freeCtxMtx, &errorMtx, - &nextRxFrameIdx, &freeCtxPool, &eptr, + client->AsyncInfer([clientId, frameIdx, &job, &cfg, &freeCtxCv, &nextRxFrameCv, + &nextRxFrameMtx, &tracksMtx, &freeClientsMtx, &errorMtx, + &nextRxFrameIdx, &freeClients, &eptr, &class_extra_tracks, &frame_track, &user_track, - this](nic::InferContext *c, sPtrInferCtxReq req) { + this](nic::InferResult* tmpResult) { // NOTE: When this callback is invoked, the frame has already been processed by the TRTIS server. - LOG4CXX_DEBUG(_log, "Async run callback for frame[" << frameIdx << "] with context[" - << c->CorrelationId() - << "] and thread_id:" - << this_thread::get_id()); + LOG4CXX_DEBUG(_log, "Async run callback for frame[" << frameIdx << "] " + << "and thread_id:" + << hex << this_thread::get_id()); + // stuff raw results pointer into a smart pointer + unique_ptr inferResults(tmpResult); // Ensure tracking is performed on the frames in the proper order. { @@ -969,21 +924,15 @@ vector TrtisDetection::GetDetections(const MPFVideoJob &job) { // Retrieve the results from the TRTIS server and update tracks. 
try { - StrUPtrInferCtxResMap res; - bool is_ready = false; - NI_CHECK_OK( - c->GetAsyncRunResults(&res, &is_ready, req, true), - "Failed to retrieve inference results for context " + - to_string(c->CorrelationId())); - if (!is_ready) { - THROW_TRTISEXCEPTION(MPF_DETECTION_FAILED, - "Inference results not ready during callback for context " + - to_string(c->CorrelationId())); + if(!inferResults->RequestStatus().IsOk()){ + THROW_TRTISEXCEPTION(MPF_DETECTION_FAILED, + "Inference failed for frame[" + to_string(frameIdx) + "]: " + + inferResults->RequestStatus().Message()); } - LOG4CXX_TRACE(_log, "inference complete"); - MPFImageLocationVec locations; - _ip_irv2_coco_getDetections(cfg, res, locations); + LOG4CXX_TRACE(_log, "inference succeeded"); + vector<MPFImageLocation> locations; + _ip_irv2_coco_getDetections(cfg, inferResults, locations); LOG4CXX_TRACE(_log, "inferenced frame[" << frameIdx << "]"); { lock_guard<mutex> lk(tracksMtx); @@ -1003,6 +952,7 @@ vector<MPFVideoTrack> TrtisDetection::GetDetections(const MPFVideoJob &job) { LOG4CXX_TRACE(_log, "tracked objects in frame[" << frameIdx << "]"); } } catch (...) { + LOG4CXX_ERROR(_log,"Error: an exception occurred"); try { Utils::LogAndReThrowException(job, _log); } catch (MPFDetectionException &e) { @@ -1023,15 +973,16 @@ vector<MPFVideoTrack> TrtisDetection::GetDetections(const MPFVideoJob &job) { nextRxFrameCv.notify_all(); } - // We're done with the context. Add it back to the pool so it can be used for another frame. + // We're done with the client. Add it back to the pool so it can be used for another frame. { - lock_guard<mutex> lk(freeCtxMtx); - freeCtxPool.insert(c->CorrelationId()); + lock_guard<mutex> lk(freeClientsMtx); + freeClients.insert(clientId); freeCtxCv.notify_all(); - LOG4CXX_TRACE(_log, "returned context[" << c->CorrelationId() << "] to pool"); + LOG4CXX_TRACE(_log, "returning client[" << clientId << "] to pool"); } LOG4CXX_DEBUG(_log, "frame[" << frameIdx << "] complete"); - }), + }, + inferOptions,getRaw(inferInputs)), "unable to inference '" + cfg.model_name + "' ver." + to_string(cfg.model_version)); LOG4CXX_DEBUG(_log, "Inference request sent for frame[" << frameIdx << "] sent"); @@ -1039,8 +990,8 @@ vector<MPFVideoTrack> TrtisDetection::GetDetections(const MPFVideoJob &job) { frameIdx++; LOG4CXX_TRACE(_log, "frameIdx++ to " << frameIdx); } while (video_cap.Read(frame)); - } catch (...) { + LOG4CXX_ERROR(_log,"Error: an exception occurred"); try { Utils::LogAndReThrowException(job, _log); } catch (MPFDetectionException &e) { @@ -1053,14 +1004,17 @@ vector<MPFVideoTrack> TrtisDetection::GetDetections(const MPFVideoJob &job) { } } + + // Always wait for async threads to complete. - if (freeCtxPool.size() < initialCtxPoolSize) { + if (freeClients.size() < initialClientPoolSize) { LOG4CXX_TRACE(_log, - "wait for inference context pool size to return to initial size of " << initialCtxPoolSize); - unique_lock<mutex> lk(freeCtxMtx); + "wait for inference client pool size to return to initial size of " << initialClientPoolSize); + unique_lock<mutex> lk(freeClientsMtx); freeCtxCv.wait(lk, - [&freeCtxPool, &initialCtxPoolSize] { return freeCtxPool.size() == initialCtxPoolSize; }); + [&freeClients, &initialClientPoolSize] { return freeClients.size() == initialClientPoolSize; }); } + LOG4CXX_DEBUG(_log, "all frames [" << job.start_frame << "..." << job.stop_frame <<"] complete"); // Abort now if an error occurred. 
if (eptr) { @@ -1068,8 +1022,6 @@ vector TrtisDetection::GetDetections(const MPFVideoJob &job) { rethrow_exception(eptr); } - LOG4CXX_DEBUG(_log, "all frames complete"); - vector tracks = move(class_extra_tracks); if (!frame_track.frame_locations.empty()) { tracks.push_back(move(frame_track)); @@ -1129,23 +1081,32 @@ vector TrtisDetection::GetDetections(const MPFImageJob &job) LOG4CXX_TRACE(_log, "parsed job configuration settings"); cfg.maxInferConcurrency = 1; - sPtrInferCtx ctx = _niGetInferContext(cfg); - LOG4CXX_TRACE(_log, "retrieved inferencing context for model '" << cfg.model_name << "' from server " - << cfg.trtis_server); + nic::InferOptions inferOptions(cfg.model_name); + if(cfg.model_version > 0){ + inferOptions.model_version_ = to_string(cfg.model_version); + } + inferOptions.client_timeout_ = cfg.clientTimeout; + + LOG4CXX_TRACE(_log, "created inferencing client for model '" + << cfg.model_name << "' for server " + << cfg.trtis_server); + + vector imgDat; + vector imgShape; + auto inferInputs = _ip_irv2_coco_prepInputData(cfg, img, imgShape, imgDat); - LngVec shape; - BytVec imgDat; - _ip_irv2_coco_prepImageData(cfg, img, ctx, shape, imgDat); - LOG4CXX_TRACE(_log, "loaded data into inference context"); + LOG4CXX_TRACE(_log, "prepared inference input data"); // Send inference request to the inference server. - StrUPtrInferCtxResMap res; - NI_CHECK_OK(ctx->Run(&res), "unable to inference '" + cfg.model_name - + "' ver." + to_string(cfg.model_version)); + nic::InferResult *tmp; + NI_CHECK_OK(_niGetInferenceClients(cfg)[0]->Infer(&tmp,inferOptions,getRaw(inferInputs)), + "unable to inference '" + cfg.model_name + + "' ver." + to_string(cfg.model_version)); + unique_ptr inferResults(tmp); LOG4CXX_TRACE(_log, "inference complete"); size_t next_idx = locations.size(); - _ip_irv2_coco_getDetections(cfg, res, locations); + _ip_irv2_coco_getDetections(cfg, inferResults, locations); LOG4CXX_TRACE(_log, "parsed detections into locations vector"); for (auto &location : locations) { diff --git a/cpp/TrtisDetection/TrtisDetection.h b/cpp/TrtisDetection/TrtisDetection.h index 7c1023a0..2b53a265 100644 --- a/cpp/TrtisDetection/TrtisDetection.h +++ b/cpp/TrtisDetection/TrtisDetection.h @@ -35,45 +35,25 @@ // Nvidia TensorRT Inference Server (trtis) client lib includes // (see https://github.com/NVIDIA/tensorrt-inference-server) -#include "request_grpc.h" -#include "request_http.h" -#include "model_config.pb.h" +//#include "request_grpc.h" +//#include "request_http.h" +//#include "model_config.pb.h" + +#include "grpc_client.h" #include "IFeatureStorage.h" namespace MPF{ namespace COMPONENT{ - namespace ni = nvidia::inferenceserver; ///< namespace alias for inference server namespace nic = nvidia::inferenceserver::client; using namespace std; - typedef vector BytVec; ///< vector of bytes - typedef vector IntVec; ///< vector of integers - typedef vector LngVec; ///< vector of 64bit integers - typedef vector FltVec; ///< vector of floats - - typedef vector MPFVideoTrackVec; - typedef vector MPFImageLocationVec; - - typedef nic::InferContext InferCtx; ///< trtis inferencing message context - typedef nic::InferContext::Input InferCtxInp; ///< inferencing message input context - typedef nic::InferContext::Result InferCtxRes; ///< inferencing message output context - typedef nic::InferContext::Options InferCtxOpt; ///< inferencing message context options - typedef nic::InferContext::Request InferCtxReq; ///< inferencing context request - typedef unique_ptr uPtrInferCtx; ///< inferencing message 
context pointer - typedef shared_ptr sPtrInferCtx; ///< inferencing message context pointer - typedef unique_ptr uPtrInferCtxOpt; ///< inference options pointer - typedef unique_ptr uPtrInferCtxRes; ///< inference results pointer - typedef shared_ptr sPtrInferCtxInp; ///< inference input context pointer - typedef shared_ptr sPtrInferCtxReq; ///< inference request context pointer - typedef map StrUPtrInferCtxResMap; ///< map of inference outputs keyed by output name - class TrtisJobConfig { private: - static IFeatureStorage::uPtrFeatureStorage _getFeatureStorage(const MPFJob &job, - const log4cxx::LoggerPtr &log); + static IFeatureStorage::uPtrFeatureStorage + _getFeatureStorage(const MPFJob &job, const log4cxx::LoggerPtr &log); public: string data_uri; ///< media to process @@ -81,6 +61,7 @@ namespace MPF{ string model_name; ///< name of model as served by trtis int model_version; ///< version of model (e.g. -1 for latest) int maxInferConcurrency; ///< maximum number of concurrent video frame inferencing request + uint32_t clientTimeout; ///< client request timeout in micro-seconds IFeatureStorage::uPtrFeatureStorage featureStorage; ///< helper for storing FEATUREs TrtisJobConfig(const MPFJob &job, const log4cxx::LoggerPtr &log); @@ -99,8 +80,9 @@ namespace MPF{ int userBBox_y; ///< user bounding box upper left y1 int userBBox_width; ///< user bounding box width int userBBox_height; ///< user bounding box height - LngVec userBBox; ///< user bounding box as [y1,x1,x2,x2] - FltVec userBBoxNorm; ///< user bounding box normalized with image dimensions + vector userBBox; ///< user bounding box as [y1,x1,x2,x2] + vector userBBoxNorm; ///< user bounding box normalized with image dimensions + vector userBBoxNormShape; ///< normalized user bounding box shape bool recognitionEnroll; ///< enroll features in recognition framework float classConfThreshold; ///< class detection confidence threshold @@ -133,38 +115,36 @@ namespace MPF{ const string &class_label_file, int class_label_count); ///< read in class labels for a model from a file - sPtrInferCtx _niGetInferContext(const TrtisJobConfig &cfg, - int ctxId = 0); ///< get cached inference contexts + unordered_map> + _niGetInferenceClients(const TrtisJobConfig &cfg); - unordered_map _niGetInferContexts(const TrtisJobConfig &cfg); ///< get cached inference contexts - static string _niType2Str(ni::DataType dt); ///< nvidia data type to string static cv::Mat _niResult2CVMat(const int batch_idx, const string &name, - StrUPtrInferCtxResMap &results); ///< make an openCV mat header for nvidia tensor + const unique_ptr &res); ///< make an openCV mat header for nvidia tensor cv::Mat _cvResize(const cv::Mat &img, double &scaleFactor, const int target_width, const int target_height); ///< aspect preserving resize image to ~[target_width, target_height] - BytVec _cvRGBBytes(const cv::Mat &img, - LngVec &shape); ///< convert image to 8-bit RGB + vector _cvRGBBytes(const cv::Mat &img, + vector &shape); ///< convert image to 8-bit RGB - void _ip_irv2_coco_prepImageData(const TrtisIpIrv2CocoJobConfig &cfg, - const cv::Mat &img, - const sPtrInferCtx &ctx, - LngVec &shape, - BytVec &imgDat); ///< prep image for inferencing + vector> + _ip_irv2_coco_prepInputData(const TrtisIpIrv2CocoJobConfig &cfg, + const cv::Mat &img, + vector &imgShape, + vector &imgDat); ///< prep input data for inferencing void _ip_irv2_coco_getDetections(const TrtisIpIrv2CocoJobConfig &cfg, - StrUPtrInferCtxResMap &res, - MPFImageLocationVec &locations); ///< parse inference results and get 
detections + const unique_ptr &res, + vector &locations); ///< parse inference results and get detections void _ip_irv2_coco_tracker(const TrtisIpIrv2CocoJobConfig &cfg, MPFImageLocation &loc, const int frameIdx, - MPFVideoTrackVec &tracks); ///< tracking using time, space and feature proximity + vector &tracks); ///< tracking using time, space and feature proximity void _addToTrack(MPFImageLocation &location, int frame_index, diff --git a/cpp/TrtisDetection/sample_trtis_detector.cpp b/cpp/TrtisDetection/sample_trtis_detector.cpp index 261e1fe6..44069ed6 100644 --- a/cpp/TrtisDetection/sample_trtis_detector.cpp +++ b/cpp/TrtisDetection/sample_trtis_detector.cpp @@ -83,7 +83,7 @@ int main(int argc, char *argv[]) { Properties media_properties; string job_name("Testing TRTIS"); - MPFImageLocationVec detections; + vector detections; MPFImageJob job(job_name, uri, algorithm_properties, media_properties); cout << "Running job..." << endl; diff --git a/cpp/TrtisDetection/test/data/log4cxx.properties b/cpp/TrtisDetection/test/data/log4cxx.properties new file mode 100644 index 00000000..ee5b9b35 --- /dev/null +++ b/cpp/TrtisDetection/test/data/log4cxx.properties @@ -0,0 +1,10 @@ +# Set root logger level to DEBUG and its only appender to A1. +log4j.rootLogger=DEBUG, A1 + +# A1 is set to be a ConsoleAppender. +log4j.appender.A1=org.apache.log4j.ConsoleAppender + +# A1 uses PatternLayout. +log4j.appender.A1.layout=org.apache.log4j.PatternLayout +log4j.appender.A1.layout.ConversionPattern=%d %p [%t] %c{36}:%L - %m%n +log4j.appender.A1.layout.DatePattern='.'yyyy-MM-dd \ No newline at end of file diff --git a/cpp/TrtisDetection/test/test_trtis_detection.cpp b/cpp/TrtisDetection/test/test_trtis_detection.cpp index edf84777..2c521b5e 100644 --- a/cpp/TrtisDetection/test/test_trtis_detection.cpp +++ b/cpp/TrtisDetection/test/test_trtis_detection.cpp @@ -26,15 +26,37 @@ #include #include +#include +#include +#include #include -#include - +#include #include "TrtisDetection.h" using namespace MPF::COMPONENT; using namespace std; +/** *************************************************************************** +* macros for "pretty" gtest messages +**************************************************************************** */ +#define ANSI_TXT_GRN "\033[0;32m" +#define ANSI_TXT_MGT "\033[0;35m" //Magenta +#define ANSI_TXT_DFT "\033[0;0m" //Console default +#define GTEST_BOX "[ ] " +#define GOUT(MSG) \ + { \ + std::cout << GTEST_BOX << MSG << std::endl; \ + } +#define GOUT_MGT(MSG) \ + { \ + std::cout << ANSI_TXT_MGT << GTEST_BOX << MSG << ANSI_TXT_DFT << std::endl; \ + } +#define GOUT_GRN(MSG) \ + { \ + std::cout << ANSI_TXT_GRN << GTEST_BOX << MSG << ANSI_TXT_DFT << std::endl; \ + } + //------------------------------------------------------------------------------ Properties getProperties_ip_irv2_coco() { @@ -53,7 +75,7 @@ bool containsObject(const string &object_name, //------------------------------------------------------------------------------ bool containsObject(const string &object_name, - const MPFImageLocationVec &locations) { + const vector &locations) { return any_of(locations.begin(), locations.end(), [&](const MPFImageLocation &location) { return containsObject(object_name, location.detection_properties); @@ -66,7 +88,12 @@ void assertObjectDetectedInImage(const string &expected_object, TrtisDetection &trtisDet) { MPFImageJob job("Test", image_path, getProperties_ip_irv2_coco(), {}); - MPFImageLocationVec image_locations = trtisDet.GetDetections(job); + vector image_locations = 
trtisDet.GetDetections(job); + + ImageGeneration image_generation; + image_generation.WriteDetectionOutputImage(image_path, + image_locations, + "test/detections.png"); ASSERT_FALSE(image_locations.empty()); @@ -75,7 +102,7 @@ void assertObjectDetectedInImage(const string &expected_object, } bool init_logging() { - log4cxx::BasicConfigurator::configure(); + log4cxx::PropertyConfigurator::configure("test/log4cxx.properties"); return true; } bool logging_initialized = init_logging(); @@ -104,7 +131,7 @@ TEST(TRTIS, ImageTest) { //------------------------------------------------------------------------------ bool containsObject(const string &object_name, - const MPFVideoTrackVec &tracks) { + const vector &tracks) { return any_of(tracks.begin(), tracks.end(), [&](const MPFVideoTrack &track) { return containsObject(object_name, track.detection_properties); @@ -115,9 +142,9 @@ bool containsObject(const string &object_name, void assertObjectDetectedInVideo(const string &object_name, const Properties &job_props, TrtisDetection &trtisDet) { - MPFVideoJob job("TEST", "test/ff-region-object-motion.avi", 11, 12, job_props, {}); + MPFVideoJob job("TEST", "test/ff-region-object-motion.avi", 0, 12, job_props, {}); - MPFVideoTrackVec tracks = trtisDet.GetDetections(job); + vector tracks = trtisDet.GetDetections(job); ASSERT_FALSE(tracks.empty()); ASSERT_TRUE(containsObject(object_name, tracks)); @@ -136,3 +163,46 @@ TEST(TRTIS, VideoTest) { ASSERT_TRUE(trtisDet.Close()); } +//------------------------------------------------------------------------------ +TEST(TRTIS, DISABLED_VideoTest2) { + TrtisDetection trtisDet; + trtisDet.SetRunDirectory("../plugin"); + + ASSERT_TRUE(trtisDet.Init()); + + Properties job_props = getProperties_ip_irv2_coco(); + job_props["USER_FEATURE_ENABLE"] = "false"; + job_props["FRAME_FEATURE_ENABLE"] = "false"; + job_props["EXTRA_FEATURE_ENABLE"] = "false"; + job_props["MAX_INFER_CONCURRENCY"] = "10"; + job_props["CONTEXT_WAIT_TIMEOUT_SEC"] = "60"; + MPFVideoJob job("TEST", + "test/ped_short.mp4", + 0, + 50, + job_props, + {{"FPS","24.0"}}); + + vector tracks; + auto start_time = chrono::high_resolution_clock::now(); + tracks = trtisDet.GetDetections(job); + auto end_time = chrono::high_resolution_clock::now(); + double time_taken = chrono::duration_cast(end_time - start_time).count(); + time_taken = time_taken * 1e-9; + int frame_count = job.stop_frame - job.start_frame + 1; + double fps = frame_count / time_taken; + GOUT("\tVideoJob processing time for "<< frame_count << " frames : " << fixed << setprecision(3) << time_taken << "[sec]"); + GOUT("\tVideoJob processing speed:" << fixed << setprecision(2) << fps << " [FPS] or " << setprecision(3) << 1000.0f/fps << "[ms] per inference"); + ASSERT_FALSE(tracks.empty()); + + GOUT("\tWriting detected video to files."); + VideoGeneration video_generation; + video_generation.WriteTrackOutputVideo(job.data_uri, tracks, "tracks.avi"); + + GOUT("\tWriting test tracks to files."); + WriteDetectionsToFile::WriteVideoTracks("tracks.txt", tracks); + + GOUT("\tClosing down detection.") + + ASSERT_TRUE(trtisDet.Close()); +} diff --git a/java/TikaTextDetection/README.md b/java/TikaTextDetection/README.md index cbacae3b..9ad0a900 100644 --- a/java/TikaTextDetection/README.md +++ b/java/TikaTextDetection/README.md @@ -2,13 +2,18 @@ This directory contains source code for the OpenMPF Tika text detection component. -Supports most document formats (.txt, .pptx, .docx, .doc, .pdf, etc.) as input. 
-Extracts text contained in document and processes text for detected languages -(71 languages currently supported). For PDF and PowerPoint documents, text will -be extracted and processed per page/slide. The first page track (with detection -property PAGE_NUM = 1) corresponds to first page of each document by default. - -Users can also enable metadata reporting. -If enabled by setting the job property STORE_METADATA = "true", document -metadata will be labeled and stored as the first track. -Metadata track will not contain the PAGE_NUM or TEXT detection properties. +This component supports most document formats (`.txt`, `.pptx`, `.docx`, `.doc`, `.pdf`, etc.) as input. It extracts +text contained in the document and processes the text for detected languages (71 languages currently supported). For PDF +and PowerPoint documents, text will be extracted and processed per page (each PowerPoint slide is treated as a page). +Unlike PowerPoint (`*.pptx`) files, OpenOffice Presentation (`*.odp`) files cannot be parsed by page. The component can +still extract all of the text from Word documents and other files that cannot be parsed by page, but all of their tracks +will have a `PAGE_NUM = 1` property. Note that page numbers start at 1, not 0. + +Every page can generate zero or more tracks, depending on the number of text sections in that page. A text section can +be a line or paragraph of text surrounded by newlines and/or page breaks, a single bullet point, a single table cell, +etc. In addition to `PAGE_NUM`, each track will also have a `SECTION_NUM` property. `SECTION_NUM` starts over at 1 on +each page. + +Users can also enable metadata reporting. If enabled by setting the job property `STORE_METADATA = "true"`, document +metadata will be labeled and stored as the first track. The metadata track will not contain the `PAGE_NUM`, `SECTION_NUM`, +or `TEXT` detection properties. 
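+
+For illustration only (the values below are hypothetical), a two-page document whose first page contains two text
+sections and whose second page contains one might produce tracks like the following:
+
+```
+Track 1: PAGE_NUM=1, SECTION_NUM=1, TEXT_LANGUAGE=English, TEXT=<text of page 1, section 1>
+Track 2: PAGE_NUM=1, SECTION_NUM=2, TEXT_LANGUAGE=English, TEXT=<text of page 1, section 2>
+Track 3: PAGE_NUM=2, SECTION_NUM=1, TEXT_LANGUAGE=Unknown, TEXT=<text of page 2, section 1>
+```
+
+For documents with 10 or more pages (or sections), the `PAGE_NUM` and `SECTION_NUM` values are zero-padded to a fixed
+width (e.g. `PAGE_NUM=01`).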
diff --git a/java/TikaTextDetection/pom.xml b/java/TikaTextDetection/pom.xml index 4f85f71c..c7469d61 100755 --- a/java/TikaTextDetection/pom.xml +++ b/java/TikaTextDetection/pom.xml @@ -53,17 +53,17 @@ org.apache.tika tika-core - 1.18 + 1.22 org.apache.tika tika-parsers - 1.18 + 1.22 org.apache.tika tika-langdetect - 1.18 + 1.22 com.fasterxml.jackson.core diff --git a/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TextExtractionContentHandler.java b/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TextExtractionContentHandler.java index 48afe961..8d6483e0 100644 --- a/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TextExtractionContentHandler.java +++ b/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TextExtractionContentHandler.java @@ -26,33 +26,40 @@ package org.mitre.mpf.detection.tika; -import org.xml.sax.SAXException; -import org.xml.sax.Attributes; import org.apache.tika.sax.ToTextContentHandler; -import java.lang.StringBuilder; +import org.xml.sax.Attributes; + import java.util.ArrayList; public class TextExtractionContentHandler extends ToTextContentHandler { - private String pageTag = "div"; - protected int pageNumber = 0; + private static final String pageTag = "div"; + private static final String sectionTag = "p"; + protected int pageNumber; + protected int sectionNumber; public StringBuilder textResults; - public ArrayList pageMap; + public ArrayList> pageMap; + private ArrayList sectionMap; private boolean skipTitle; + private boolean skipBlankSections; public TextExtractionContentHandler(){ super(); - pageTag = "div"; pageNumber = 0; + sectionNumber = 0; // Enable to avoid storing metadata/title text from ppt document. skipTitle = true; - textResults = new StringBuilder(); - pageMap = new ArrayList(); - pageMap.add(new StringBuilder()); + // Disable to skip recording empty sections (warning: could produce an excessive number of empty tracks). 
+ skipBlankSections = true; + textResults = new StringBuilder(); + pageMap = new ArrayList<>(); + sectionMap = new ArrayList<>(); + pageMap.add(sectionMap); + sectionMap.add(new StringBuilder()); } - public void startElement (String uri, String localName, String qName, Attributes atts) throws SAXException { + public void startElement (String uri, String localName, String qName, Attributes atts) { if (atts.getValue("class") != null) { if (pageTag.equals(qName) && (atts.getValue("class").equals("page"))) { startPage(); @@ -67,44 +74,59 @@ public void startElement (String uri, String localName, String qName, Attributes startPage(); } } + } else if (sectionTag.equals(qName)) { + newSection(); } } - public void endElement (String uri, String localName, String qName) throws SAXException { + public void endElement (String uri, String localName, String qName) { if (pageTag.equals(qName)) { endPage(); } } - public void characters(char[] ch, int start, int length) throws SAXException { + public void characters(char[] ch, int start, int length) { if (length > 0) { textResults.append(ch, start, length); - pageMap.get(pageNumber).append(ch, start, length); + pageMap.get(pageNumber).get(sectionNumber).append(ch, start, length); } } - protected void startPage() throws SAXException { + protected void startPage() { pageNumber ++; - pageMap.add(new StringBuilder()); + sectionNumber = 0; + sectionMap = new ArrayList<>(); + sectionMap.add(new StringBuilder()); + pageMap.add(sectionMap); } - protected void endPage() throws SAXException { - return; - } + protected void endPage() {} - protected void resetPage() throws SAXException { + protected void resetPage() { pageNumber = 0; + sectionNumber = 0; + sectionMap.clear(); + sectionMap.add(new StringBuilder()); pageMap.clear(); - pageMap.add(new StringBuilder()); + pageMap.add(sectionMap); + + } + + protected void newSection() { + if (skipBlankSections && sectionMap.get(sectionNumber).toString().trim().isEmpty()){ + return; + } + sectionNumber++; + sectionMap.add(new StringBuilder()); } public String toString(){ return textResults.toString(); } - // Returns the text detections, subdivided by page number. - public ArrayList getPages(){ + // Returns the text detections, subdivided by page number and section. + public ArrayList> getPages(){ return pageMap; } } diff --git a/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TikaTextDetectionComponent.java b/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TikaTextDetectionComponent.java index e8bb3864..49a8bed6 100755 --- a/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TikaTextDetectionComponent.java +++ b/java/TikaTextDetection/src/main/java/org/mitre/mpf/detection/tika/TikaTextDetectionComponent.java @@ -41,7 +41,6 @@ import java.io.File; import java.io.FileInputStream; -import java.io.IOException; import java.util.*; public class TikaTextDetectionComponent extends MPFDetectionComponentBase { @@ -61,15 +60,10 @@ public List getDetections(MPFGenericJob mpfGenericJob) throws mpfGenericJob.getJobName(), mpfGenericJob.getDataUri(), mpfGenericJob.getJobProperties().size(), mpfGenericJob.getMediaProperties().size()); - // ========================= - // Tika Detection - // ========================= - // Specify filename for tika parsers here. 
File file = new File(mpfGenericJob.getDataUri()); - - List pageOutput = new ArrayList(); + ArrayList> pageOutput; Metadata metadata = new Metadata(); try (FileInputStream inputstream = new FileInputStream(file)) { // Init parser with custom content handler for parsing text per page (PDF/PPTX). @@ -82,13 +76,13 @@ public List getDetections(MPFGenericJob mpfGenericJob) throws pageOutput = handler.getPages(); } catch (Exception e) { - String errorMsg = String.format("Error parsing file. Filepath = %s", file.toString()); + String errorMsg = String.format("Error parsing file. Filepath = %s", file); LOG.error(errorMsg, e); throw new MPFComponentDetectionError(MPFDetectionError.MPF_COULD_NOT_READ_DATAFILE, errorMsg); } float confidence = -1.0f; - List tracks = new LinkedList(); + List tracks = new LinkedList<>(); Map properties = mpfGenericJob.getJobProperties(); @@ -99,8 +93,8 @@ public List getDetections(MPFGenericJob mpfGenericJob) throws // Store metadata as a unique track. // Disabled by default for format consistency. if (MapUtils.getBooleanValue(properties, "STORE_METADATA")) { - Map genericDetectionProperties = new HashMap(); - Map metadataMap = new HashMap(); + Map genericDetectionProperties = new HashMap<>(); + Map metadataMap = new HashMap<>(); String[] metadataKeys = metadata.names(); for (String s: metadataKeys) { @@ -120,81 +114,91 @@ public List getDetections(MPFGenericJob mpfGenericJob) throws tracks.add(metadataTrack); } + boolean listAllPages = MapUtils.getBooleanValue(properties, "LIST_ALL_PAGES", false); // If output exists, separate all output into separate pages. // Tag each page by detected language. - if (pageOutput.size() >= 1) { + if (!pageOutput.isEmpty()) { // Load language identifier. OptimaizeLangDetector identifier = new OptimaizeLangDetector(); + identifier.loadModels(); - try { - identifier.loadModels(); - } catch (IOException e) { - String errorMsg = "Failed to load language models."; - LOG.error(errorMsg, e); - throw new MPFComponentDetectionError(MPFDetectionError.MPF_DETECTION_FAILED, errorMsg, e); - } + int maxIDLength = (int) (Math.log10(pageOutput.size())) + 1; - int pageIDLen = (int) (java.lang.Math.log10(pageOutput.size())) + 1; - for (int i = 0; i < pageOutput.size(); i++) { + int maxSectionsOnPage = pageOutput.stream().mapToInt(ArrayList::size).max().getAsInt(); + int sectionIDLength = (int) (Math.log10(maxSectionsOnPage)) + 1; - Map genericDetectionProperties = new HashMap(); + if (sectionIDLength > maxIDLength) { + maxIDLength = sectionIDLength; + } + for (int p = 0; p < pageOutput.size(); p++) { + if (pageOutput.get(p).size() == 1 && pageOutput.get(p).get(0).toString().trim().isEmpty()) { + // If LIST_ALL_PAGES is true, create empty tracks for empty pages. + if (listAllPages) { + Map genericDetectionProperties = new HashMap<>(); + genericDetectionProperties.put("TEXT", ""); + genericDetectionProperties.put("TEXT_LANGUAGE", "Unknown"); + genericDetectionProperties.put("PAGE_NUM", String.format("%0" + maxIDLength + "d", p + 1)); + genericDetectionProperties.put("SECTION_NUM", String.format("%0" + maxIDLength + "d", 1)); + MPFGenericTrack genericTrack = new MPFGenericTrack(confidence, genericDetectionProperties); + tracks.add(genericTrack); + } + continue; + } - try { - String textDetect = pageOutput.get(i).toString(); + for (int s = 0; s < pageOutput.get(p).size(); s++) { - // By default, trim out detected text. 
- textDetect = textDetect.trim(); + Map genericDetectionProperties = new HashMap<>(); + try { + String textDetect = pageOutput.get(p).get(s).toString(); - if (textDetect.length() > 0) { - genericDetectionProperties.put("TEXT", textDetect); - } - else{ - if (!MapUtils.getBooleanValue(properties, "LIST_ALL_PAGES", false)) { + // By default, trim out detected text. + textDetect = textDetect.trim(); + if (textDetect.isEmpty()) { continue; } - genericDetectionProperties.put("TEXT", ""); - genericDetectionProperties.put("TEXT_LANGUAGE", "Unknown"); - } - // Process text languages. - if (textDetect.length() >= charLimit) { - LanguageResult langResult = identifier.detect(textDetect); - String language = langResult.getLanguage(); + genericDetectionProperties.put("TEXT", textDetect); - if (langMap.containsKey(language)) { - language = langMap.get(language); - } - if (!langResult.isReasonablyCertain()) { - language = null; - } - if (language != null && language.length() > 0) { - genericDetectionProperties.put("TEXT_LANGUAGE", language); + // Process text languages. + if (textDetect.length() >= charLimit) { + LanguageResult langResult = identifier.detect(textDetect); + String language = langResult.getLanguage(); + + if (langMap.containsKey(language)) { + language = langMap.get(language); + } + if (!langResult.isReasonablyCertain()) { + language = null; + } + if (language != null && language.length() > 0) { + genericDetectionProperties.put("TEXT_LANGUAGE", language); + } else { + genericDetectionProperties.put("TEXT_LANGUAGE", "Unknown"); + } } else { genericDetectionProperties.put("TEXT_LANGUAGE", "Unknown"); } - } else { - genericDetectionProperties.put("TEXT_LANGUAGE", "Unknown"); - } - - } catch (Exception e) { - String errorMsg = String.format("Failed to process text detections."); - LOG.error(errorMsg, e); - throw new MPFComponentDetectionError(MPFDetectionError.MPF_DETECTION_FAILED, errorMsg); - } + } catch (Exception e) { + String errorMsg = "Failed to process text detections."; + LOG.error(errorMsg, e); + throw new MPFComponentDetectionError(MPFDetectionError.MPF_DETECTION_FAILED, errorMsg); + } - genericDetectionProperties.put("PAGE_NUM",String.format("%0" + String.valueOf(pageIDLen) + "d", i + 1)); - MPFGenericTrack genericTrack = new MPFGenericTrack(confidence, genericDetectionProperties); - tracks.add(genericTrack); + genericDetectionProperties.put("PAGE_NUM", String.format("%0" + maxIDLength + "d", p + 1)); + genericDetectionProperties.put("SECTION_NUM", String.format("%0" + maxIDLength + "d", s + 1)); + MPFGenericTrack genericTrack = new MPFGenericTrack(confidence, genericDetectionProperties); + tracks.add(genericTrack); + } } } // If entire document is empty, generate a single track reporting no detections. - if (tracks.size() == 0) { + if (tracks.isEmpty()) { LOG.warn("Empty or invalid document. No extracted text."); } @@ -205,7 +209,7 @@ public List getDetections(MPFGenericJob mpfGenericJob) throws // Map for translating from ISO 639-2 code to english description. 
private static Map initLangMap() { - Map map = new HashMap(); + Map map = new HashMap<>(); map.put("af", "Afrikaans"); map.put("an", "Aragonese"); map.put("ar", "Arabic"); diff --git a/java/TikaTextDetection/src/test/java/org/mitre/mpf/detection/tika/TestTikaTextDetectionComponent.java b/java/TikaTextDetection/src/test/java/org/mitre/mpf/detection/tika/TestTikaTextDetectionComponent.java index 3bfcbe70..6b851a43 100755 --- a/java/TikaTextDetection/src/test/java/org/mitre/mpf/detection/tika/TestTikaTextDetectionComponent.java +++ b/java/TikaTextDetection/src/test/java/org/mitre/mpf/detection/tika/TestTikaTextDetectionComponent.java @@ -63,7 +63,7 @@ public void tearDown() { } @Test - public void testGetDetectionsGeneric() throws MPFComponentDetectionError { + public void testGetDetectionsPowerPointFile() throws MPFComponentDetectionError { String mediaPath = this.getClass().getResource("/data/test-tika-detection.pptx").getPath(); Map jobProperties = new HashMap<>(); @@ -75,26 +75,26 @@ public void testGetDetectionsGeneric() throws MPFComponentDetectionError { boolean debug = false; List tracks = tikaComponent.getDetections(genericJob); - assertEquals("Number of expected tracks does not match.", 11 ,tracks.size()); + assertEquals("Number of expected tracks does not match.", 23 ,tracks.size()); // Test each output type. MPFGenericTrack testTrack = tracks.get(0); assertEquals("Expected language does not match.", "English", testTrack.getDetectionProperties().get("TEXT_LANGUAGE")); - assertEquals("Expected text does not match.", "Testing Text Detection\nSlide 1", testTrack.getDetectionProperties().get("TEXT")); + assertEquals("Expected text does not match.", "Testing Text Detection", testTrack.getDetectionProperties().get("TEXT")); // Test language extraction. - testTrack = tracks.get(1); + testTrack = tracks.get(3); assertEquals("Expected language does not match.", "Japanese", testTrack.getDetectionProperties().get("TEXT_LANGUAGE")); // Test no detections. - testTrack = tracks.get(4); - assertEquals("Text should be empty", "", testTrack.getDetectionProperties().get("TEXT")); + testTrack = tracks.get(9); + assertTrue("Text should be empty", testTrack.getDetectionProperties().get("TEXT").isEmpty()); assertEquals("Language should be empty", "Unknown", testTrack.getDetectionProperties().get("TEXT_LANGUAGE")); - testTrack = tracks.get(9); + testTrack = tracks.get(20); assertThat(testTrack.getDetectionProperties().get("TEXT"), containsString("All human beings are born free")); - testTrack = tracks.get(10); + testTrack = tracks.get(22); assertThat(testTrack.getDetectionProperties().get("TEXT"), containsString("End slide test text")); // For human testing. 
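As a worked example of the zero-padded track ID formatting performed in TikaTextDetectionComponent.java above (the page and section counts here are hypothetical, not taken from the test media):

```
pages = 12              ->  maxIDLength     = (int) log10(12)  + 1 = 2
maxSectionsOnPage = 105 ->  sectionIDLength = (int) log10(105) + 1 = 3
width = max(2, 3) = 3   ->  tracks carry properties like PAGE_NUM=007, SECTION_NUM=094
```

Padding both properties to the same width keeps tracks lexicographically sortable by `PAGE_NUM` and `SECTION_NUM` across the whole document.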
@@ -111,8 +111,8 @@ public void testGetDetectionsGeneric() throws MPFComponentDetectionError { } @Test - public void testGetDetectionsPowerPointFile() throws MPFComponentDetectionError { - String mediaPath = this.getClass().getResource("/data/test-tika-detection.pptx").getPath(); + public void testGetDetectionsDocumentFile() throws MPFComponentDetectionError { + String mediaPath = this.getClass().getResource("/data/test-tika-detection.docx").getPath(); Map jobProperties = new HashMap<>(); Map mediaProperties = new HashMap<>(); @@ -122,7 +122,13 @@ public void testGetDetectionsPowerPointFile() throws MPFComponentDetectionError MPFGenericJob genericJob = new MPFGenericJob("TestGenericJob", mediaPath, jobProperties, mediaProperties); List tracks = tikaComponent.getDetections(genericJob); - assertEquals("Number of expected tracks does not match.", 11, tracks.size()); + assertEquals("Number of expected tracks does not match.", 6, tracks.size()); + assertThat(tracks.get(0).getDetectionProperties().get("TEXT"), + containsString("first section")); + assertThat(tracks.get(1).getDetectionProperties().get("TEXT"), + containsString("second section")); + assertThat(tracks.get(2).getDetectionProperties().get("TEXT"), + containsString("third section")); } @Test diff --git a/java/TikaTextDetection/src/test/resources/data/NOTICE b/java/TikaTextDetection/src/test/resources/data/NOTICE index b58b1edb..96fdd4e8 100644 --- a/java/TikaTextDetection/src/test/resources/data/NOTICE +++ b/java/TikaTextDetection/src/test/resources/data/NOTICE @@ -10,7 +10,8 @@ Contains public domain text from the following sources: Declaration of Human Rights. Public Domain - +# test-tika-detection.docx +Contains custom text for testing Tika text detection. # test-tika-detection.xlsx Contains custom text for testing Tika text detection. diff --git a/java/TikaTextDetection/src/test/resources/data/test-tika-detection.docx b/java/TikaTextDetection/src/test/resources/data/test-tika-detection.docx new file mode 100755 index 00000000..93f0f061 Binary files /dev/null and b/java/TikaTextDetection/src/test/resources/data/test-tika-detection.docx differ
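For reviewers unfamiliar with the Triton 2.x C++ client API that this patch migrates to, the following is a condensed, self-contained sketch of the call sequence used throughout the new TrtisDetection.cpp. The server URL, model name, tensor names, and shapes are placeholder assumptions for illustration, not values taken from this repository, and error handling is reduced to early returns.

```c++
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "grpc_client.h"

namespace nic = nvidia::inferenceserver::client;

int main() {
    // One gRPC client per concurrent request; the component keeps a pool of these.
    std::unique_ptr<nic::InferenceServerGrpcClient> client;
    nic::Error err = nic::InferenceServerGrpcClient::Create(
            &client, "localhost:8001", /* verbose */ false);
    if (!err.IsOk()) { std::cerr << err.Message() << '\n'; return 1; }

    // Describe one UINT8 input tensor and attach its raw bytes.
    std::vector<int64_t> shape{1, 600, 1024, 3};               // placeholder shape
    nic::InferInput *rawInput = nullptr;
    err = nic::InferInput::Create(&rawInput, "image_input", shape, "UINT8");
    if (!err.IsOk()) { std::cerr << err.Message() << '\n'; return 1; }
    std::unique_ptr<nic::InferInput> input(rawInput);          // take ownership
    std::vector<uint8_t> pixels(1 * 600 * 1024 * 3, 0);        // stand-in image data
    err = input->AppendRaw(pixels);
    if (!err.IsOk()) { std::cerr << err.Message() << '\n'; return 1; }

    // Model name, optional version, and client-side timeout in microseconds.
    nic::InferOptions options("ip_irv2_coco");                 // placeholder model name
    options.client_timeout_ = 0;                               // 0 means no timeout

    // Synchronous request; AsyncInfer() takes the same options and inputs plus a
    // callback that receives the raw InferResult*, as in the video-job path above.
    nic::InferResult *rawResult = nullptr;
    err = client->Infer(&rawResult, options, {input.get()});
    if (!err.IsOk()) { std::cerr << err.Message() << '\n'; return 1; }
    std::unique_ptr<nic::InferResult> result(rawResult);

    // Per-output accessors replace the old InferContext::Result API.
    const uint8_t *buf = nullptr;
    size_t byteSize = 0;
    std::vector<int64_t> outShape;
    std::string datatype;
    result->RawData("detection_boxes", &buf, &byteSize);       // placeholder output name
    result->Shape("detection_boxes", &outShape);
    result->Datatype("detection_boxes", &datatype);
    std::cout << "received " << byteSize << " bytes of " << datatype << '\n';
    return 0;
}
```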