From 09e9aeed5a4715fa675529ca2e1bf1d66a95288a Mon Sep 17 00:00:00 2001 From: horcham <690936541@qq.com> Date: Wed, 31 Jan 2024 17:50:16 +0800 Subject: [PATCH 01/12] Fix problems of psenet-ctw1500 training --- configs/det/psenet/pse_r152_ctw1500.yaml | 9 ++++++--- deploy/py_infer/infer.py | 2 +- .../src/data_process/postprocess/det_db_postprocess.py | 5 ++++- docs/cn/datasets/synthtext.md | 2 +- docs/en/datasets/synthtext.md | 2 +- tools/dataset_converters/synthtext.py | 2 +- tools/export_convert_tool.py | 4 ++-- 7 files changed, 16 insertions(+), 10 deletions(-) diff --git a/configs/det/psenet/pse_r152_ctw1500.yaml b/configs/det/psenet/pse_r152_ctw1500.yaml index d6a8113b1..8eb105d47 100644 --- a/configs/det/psenet/pse_r152_ctw1500.yaml +++ b/configs/det/psenet/pse_r152_ctw1500.yaml @@ -72,9 +72,12 @@ train: - RandomColorAdjust: brightness: 0.1255 # 32.0 / 255 saturation: 0.5 - - IaaAugment: - Fliplr: { p: 0.5 } - Affine: { rotate: [ -10, 10 ] } + - RandomHorizontalFlip: + p: 0.5 + - RandomRotate: + degrees: [ -10, 10 ] + expand_canvas: False + p: 1.0 - PSEGtDecode: kernel_num: 7 min_shrink_ratio: 0.4 diff --git a/deploy/py_infer/infer.py b/deploy/py_infer/infer.py index 8ba5f7a81..f77098790 100644 --- a/deploy/py_infer/infer.py +++ b/deploy/py_infer/infer.py @@ -12,7 +12,7 @@ def main(): args = infer_args.get_args() parallel_pipeline = ParallelPipeline(args) parallel_pipeline.start_pipeline() - parallel_pipeline.infer_for_images(args.input_images_dir) + parallel_pipeline.infer_for_images(args.input_images_dir, 0) parallel_pipeline.stop_pipeline() diff --git a/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py b/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py index abe481c9a..692caaeb9 100644 --- a/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py +++ b/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py @@ -84,7 +84,10 @@ def __call__( src_w, src_h = shape_list[0, 1], shape_list[0, 0] polys = self.filter_tag_det_res(result["polys"][0], [src_h, src_w]) if self._if_merge_longedge_bbox: - polys = longedge_bbox_merge(polys, self._merge_inter_area_thres, self._merge_ratio, self._merge_angle_theta) + try: + polys = longedge_bbox_merge(polys, self._merge_inter_area_thres, self._merge_ratio, self._merge_angle_theta) + except Exception as e: + _logger.warning(f"long edge bbox merge failed: {e}") if self._if_sort_bbox: polys = sorted_boxes(polys, self._sort_bbox_y_delta) result["polys"][0] = polys diff --git a/docs/cn/datasets/synthtext.md b/docs/cn/datasets/synthtext.md index ac77e5475..b20f01119 100644 --- a/docs/cn/datasets/synthtext.md +++ b/docs/cn/datasets/synthtext.md @@ -22,7 +22,7 @@ path-to-data-dir/ > :warning: 另外, 我们强烈建议在使用 `SynthText` 数据集之前先进行预处理,因为它包含一些错误的数据。可以使用下列的方式进行校正: > ```shell -> python tools/dataset_converters/convert.py --dataset_name=synthtext --task=det --label_dir=/path-to-data-dir/SynthText/gt.mat --output_path=/path-to-data-dir/SynthText/gt_processed.mat +> python tools/dataset_converters/convert.py --dataset_name=synthtext --task=det --label_dir=/path-to-data-dir/SynthText/gt.mat --output_path=/path-to-data-dir/SynthText/gt_processed.mat --image_dir=/path-to-data-dir/SynthText > ``` > 以上的操作会产生与`SynthText`原始标注格式相同但是是经过过滤后的标注数据. 
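The documentation change above adds the `--image_dir` argument now required by the SynthText preprocessing command. As a quick sanity check of the resulting `gt_processed.mat`, one can compare its sample count against the original annotation file. The snippet below is an illustrative sketch only, not part of this patch; it assumes the standard SynthText `.mat` layout with an `imnames` field and reuses the example paths from the documentation above.

```python
# Illustrative sanity check for the filtered SynthText annotations.
# Assumptions: standard SynthText .mat layout with an "imnames" field,
# and the example paths from the documentation above.
import scipy.io as sio

original = sio.loadmat("/path-to-data-dir/SynthText/gt.mat")
processed = sio.loadmat("/path-to-data-dir/SynthText/gt_processed.mat")

# The processed file keeps the original format but drops faulty samples,
# so it should contain no more entries than the original.
print("original samples: ", original["imnames"].shape[-1])
print("processed samples:", processed["imnames"].shape[-1])
```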
diff --git a/docs/en/datasets/synthtext.md b/docs/en/datasets/synthtext.md index a2a653f0e..2178c91c0 100644 --- a/docs/en/datasets/synthtext.md +++ b/docs/en/datasets/synthtext.md @@ -23,7 +23,7 @@ path-to-data-dir/ > :warning: Additionally, It is strongly recommended to pre-process the `SynthText` dataset before using it as it contains some faulty data: > ```shell -> python tools/dataset_converters/convert.py --dataset_name=synthtext --task=det --label_dir=/path-to-data-dir/SynthText/gt.mat --output_path=/path-to-data-dir/SynthText/gt_processed.mat +> python tools/dataset_converters/convert.py --dataset_name=synthtext --task=det --label_dir=/path-to-data-dir/SynthText/gt.mat --output_path=/path-to-data-dir/SynthText/gt_processed.mat --image_dir=/path-to-data-dir/SynthText > ``` > This operation will generate a filtered output in the same format as the original `SynthText`. diff --git a/tools/dataset_converters/synthtext.py b/tools/dataset_converters/synthtext.py index bc2328d6c..237b105b4 100644 --- a/tools/dataset_converters/synthtext.py +++ b/tools/dataset_converters/synthtext.py @@ -61,7 +61,7 @@ def _sort_and_validate(self, sample: Tuple[np.ndarray, ...]) -> Tuple[np.ndarray def convert(self, task="det", image_dir=None, label_path=None, output_path=None): if task == "det": - self.convert_det(image_dir, label_path, output_path, save_output=True) + self.convert_det(image_dir, label_path, output_path) elif task == "rec_lmdb": self.convert_rec_lmdb(image_dir, label_path, output_path) else: diff --git a/tools/export_convert_tool.py b/tools/export_convert_tool.py index ee67dcc11..16a6ac0c5 100644 --- a/tools/export_convert_tool.py +++ b/tools/export_convert_tool.py @@ -77,8 +77,8 @@ def convert_mindir(self, model, info, input_file, config_file, force=False): else: log = f"{converted_model_path}.mindir exists and it will be overwritten if exported successfully." subprocess.call(f"echo {log}".split(), stdout=self.log_handle, stderr=self.log_handle) + os.remove(converted_model_path) print(log) - os.remove(converted_model_path) command = ( f"{self.convert_tool} --fmk=MINDIR --modelFile={input_file} --outputFile={converted_model_path}" + f" --optimize=ascend_oriented --configFile={config_file}" @@ -344,8 +344,8 @@ def export_mindir(self, model, is_dynamic, data_shape_h_w, model_type, input_fil else: log = f"{export_mindir_path} exists and it will be overwritten if exported successfully." 
subprocess.call(f"echo {log}".split(), stdout=self.log_handle, stderr=self.log_handle) + os.remove(export_mindir_path) print(log) - os.remove(export_mindir_path) command = f"python export.py --model_name_or_config {model} --save_dir {self.save_path}" if len(input_file) > 0 and os.path.exists(input_file): From 828fbb7792ac68f62bc8a0db9f7abcb626c428fc Mon Sep 17 00:00:00 2001 From: horcham <690936541@qq.com> Date: Wed, 31 Jan 2024 17:50:16 +0800 Subject: [PATCH 02/12] Fix problems of psenet-ctw1500 training --- README.md | 2 +- README_CN.md | 2 +- configs/det/dbnet/README.md | 4 ++-- configs/det/dbnet/README_CN.md | 6 +++--- .../dbnet/{db++_r50_icdar15.yaml => dbpp_r50_icdar15.yaml} | 0 ...{db++_r50_icdar15_910.yaml => dbpp_r50_icdar15_910.yaml} | 0 .../src/data_process/postprocess/det_db_postprocess.py | 4 +++- deploy/py_infer/src/utils/adapted/mindocr_models.py | 2 +- docs/cn/inference/inference_quickstart.md | 4 ++-- docs/en/inference/inference_quickstart.md | 4 ++-- tests/ut/test_models.py | 2 +- tools/infer/text/predict_from_yaml.py | 2 +- 12 files changed, 17 insertions(+), 15 deletions(-) rename configs/det/dbnet/{db++_r50_icdar15.yaml => dbpp_r50_icdar15.yaml} (100%) rename configs/det/dbnet/{db++_r50_icdar15_910.yaml => dbpp_r50_icdar15_910.yaml} (100%) diff --git a/README.md b/README.md index ef3279de5..a480ad008 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ You may adapt it to your task/dataset, for example, by running ```shell # train text detection model DBNet++ on icdar15 dataset -python tools/train.py --config configs/det/dbnet/db++_r50_icdar15.yaml +python tools/train.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml ``` ```shell diff --git a/README_CN.md b/README_CN.md index 62b56ddf8..a0b0b977c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -108,7 +108,7 @@ MindOCR在`configs`文件夹中提供系列SoTA的OCR模型及其训练策略, ```shell # train text detection model DBNet++ on icdar15 dataset -python tools/train.py --config configs/det/dbnet/db++_r50_icdar15.yaml +python tools/train.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml ``` ```shell # train text recognition model CRNN on icdar15 dataset diff --git a/configs/det/dbnet/README.md b/configs/det/dbnet/README.md index f79c1fa98..586f85c73 100644 --- a/configs/det/dbnet/README.md +++ b/configs/det/dbnet/README.md @@ -92,8 +92,8 @@ DBNet and DBNet++ were trained on the ICDAR2015, MSRA-TD500, SCUT-CTW1500, Total | DBNet | D910x1-MS2.0-G | ResNet-50 | ImageNet | 83.53% | 86.62% | 85.05% | 13.3 s/epoch | 75.2 img/s | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | DBNet | D910x8-MS2.2-G | ResNet-50 | ImageNet | 82.62% | 88.54% | 85.48% | 2.3 s/epoch | 435 img/s | [yaml](db_r50_icdar15_8p.yaml) | Coming soon | | | | | | | | | | | | | -| DBNet++ | D910x1-MS2.0-G | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | [yaml](db++_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| DBNet++ | D910x1-MS2.2-G | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | [yaml](db++_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| 
[mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) | +| DBNet++ | D910x1-MS2.0-G | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | [yaml](dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| DBNet++ | D910x1-MS2.2-G | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) | > The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are `(1,3,736,1280)` and `(1,3,1152,2048)`, respectively. diff --git a/configs/det/dbnet/README_CN.md b/configs/det/dbnet/README_CN.md index ace68da1f..80fd82e3d 100644 --- a/configs/det/dbnet/README_CN.md +++ b/configs/det/dbnet/README_CN.md @@ -74,9 +74,9 @@ DBNet和DBNet++在ICDAR2015,MSRA-TD500,SCUT-CTW1500,Total-Text和MLT2017 | DBNet | D910x1-MS2.0-G | ResNet-50 | ImageNet | 83.53% | 86.62% | 85.05% | 13.3 s/epoch | 75.2 img/s | [yaml](db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | DBNet | D910x8-MS2.2-G | ResNet-50 | ImageNet | 82.62% | 88.54% | 85.48% | 2.3 s/epoch | 435 img/s | [yaml](db_r50_icdar15_8p.yaml) | Coming soon | | | | | | | | | | | | | -| DBNet++ | D910x1-MS2.0-G | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | [yaml](db++_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| DBNet++ | D910x8-MS2.2-G | ResNet-50 | SynthText | 85.41% | 89.55% | 87.43% | 1.78 s/epoch | 432 img/s | [yaml](db++_r50_icdar15_8p.yaml) | Coming soon | -| DBNet++ | D910*x1-MS2.2-G | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | [yaml](db++_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) | +| DBNet++ | D910x1-MS2.0-G | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | [yaml](dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| DBNet++ | D910x8-MS2.2-G | ResNet-50 | SynthText | 85.41% | 89.55% | 87.43% | 1.78 s/epoch | 432 img/s | [yaml](dbpp_r50_icdar15_8p.yaml) | Coming soon | +| DBNet++ | D910*x1-MS2.2-G | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | [yaml](dbpp_r50_icdar15_910.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_910-35dc71f2-e61a9c37.mindir) | > 链接中模型DBNet的MindIR导出时的输入Shape为`(1,3,736,1280)`,模型DBNet++的MindIR导出时的输入Shape为`(1,3,1152,2048)`。 diff --git 
a/configs/det/dbnet/db++_r50_icdar15.yaml b/configs/det/dbnet/dbpp_r50_icdar15.yaml similarity index 100% rename from configs/det/dbnet/db++_r50_icdar15.yaml rename to configs/det/dbnet/dbpp_r50_icdar15.yaml diff --git a/configs/det/dbnet/db++_r50_icdar15_910.yaml b/configs/det/dbnet/dbpp_r50_icdar15_910.yaml similarity index 100% rename from configs/det/dbnet/db++_r50_icdar15_910.yaml rename to configs/det/dbnet/dbpp_r50_icdar15_910.yaml diff --git a/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py b/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py index 692caaeb9..0a46d7f4f 100644 --- a/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py +++ b/deploy/py_infer/src/data_process/postprocess/det_db_postprocess.py @@ -85,7 +85,9 @@ def __call__( polys = self.filter_tag_det_res(result["polys"][0], [src_h, src_w]) if self._if_merge_longedge_bbox: try: - polys = longedge_bbox_merge(polys, self._merge_inter_area_thres, self._merge_ratio, self._merge_angle_theta) + polys = longedge_bbox_merge( + polys, self._merge_inter_area_thres, self._merge_ratio, self._merge_angle_theta + ) except Exception as e: _logger.warning(f"long edge bbox merge failed: {e}") if self._if_sort_bbox: diff --git a/deploy/py_infer/src/utils/adapted/mindocr_models.py b/deploy/py_infer/src/utils/adapted/mindocr_models.py index 7d30c84cb..b9c080862 100644 --- a/deploy/py_infer/src/utils/adapted/mindocr_models.py +++ b/deploy/py_infer/src/utils/adapted/mindocr_models.py @@ -4,7 +4,7 @@ MINDOCR_MODELS = { "en_ms_det_dbnet_resnet50": "det/dbnet/db_r50_icdar15.yaml", - "en_ms_det_dbnetpp_resnet50": "det/dbnet/db++_r50_icdar15.yaml", + "en_ms_det_dbnetpp_resnet50": "det/dbnet/dbnetpp_r50_icdar15.yaml", "en_ms_det_psenet_resnet152": "det/psenet/pse_r152_icdar15.yaml", "en_ms_det_psenet_resnet50": "det/psenet/pse_r50_icdar15.yaml", "en_ms_det_psenet_mobilenetv3": "det/psenet/pse_mv3_icdar15.yaml", diff --git a/docs/cn/inference/inference_quickstart.md b/docs/cn/inference/inference_quickstart.md index 26b8d7220..18ee2c6a5 100644 --- a/docs/cn/inference/inference_quickstart.md +++ b/docs/cn/inference/inference_quickstart.md @@ -8,8 +8,8 @@ | | ResNet-18 | en | IC15 | 81.73 | 24.04 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) | | | ResNet-50 | en | IC15 | 85.00 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | | ResNet-50 | ch + en | 12个数据集 | 83.41 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) | -| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db++_r50_icdar15.yaml) | 
[ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| | ResNet-50 | ch + en | 12个数据集 | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db++_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | +| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| | ResNet-50 | ch + en | 12个数据集 | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | | [EAST](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east) | ResNet-50 | en | IC15 | 86.86 | 6.72 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east/east_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359-5f05cd42.mindir) | | | MobileNetV3 | en | IC15 | 75.32 | 26.77 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/det/east/east_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1-5bf242c5.mindir) | | [PSENet](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet) | ResNet-152 | en | IC15 | 82.50 | 2.52 | (1,3,1472,2624) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet/pse_r152_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798-0d755205.mindir) | diff --git a/docs/en/inference/inference_quickstart.md b/docs/en/inference/inference_quickstart.md index 75ed980a1..f0d026f8d 100644 --- a/docs/en/inference/inference_quickstart.md +++ b/docs/en/inference/inference_quickstart.md @@ -8,8 +8,8 @@ | | ResNet-18 | en | IC15 | 81.73 | 24.04 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) | | | ResNet-50 | en | IC15 | 85.00 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | 
[ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | | ResNet-50 | ch + en | 12 Datasets | 83.41 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) | -| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db++_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| | ResNet-50 | ch + en | 12 Datasets | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db++_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | +| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| | ResNet-50 | ch + en | 12 Datasets | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | | [EAST](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east) | ResNet-50 | en | IC15 | 86.86 | 6.72 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east/east_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359-5f05cd42.mindir) | | | MobileNetV3 | en | IC15 | 75.32 | 26.77 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/det/east/east_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1-5bf242c5.mindir) | | [PSENet](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet) | ResNet-152 | en | IC15 | 82.50 | 2.52 | (1,3,1472,2624) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet/pse_r152_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798.ckpt) \| 
[mindir](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798-0d755205.mindir) | diff --git a/tests/ut/test_models.py b/tests/ut/test_models.py index 371676633..240874a7a 100644 --- a/tests/ut/test_models.py +++ b/tests/ut/test_models.py @@ -15,7 +15,7 @@ all_yamls = [ "configs/det/dbnet/db_r50_icdar15.yaml", - "configs/det/dbnet/db++_r50_icdar15.yaml", + "configs/det/dbnet/dbnetpp_r50_icdar15.yaml", "configs/rec/crnn/crnn_resnet34.yaml", "configs/rec/master/master_resnet31.yaml", "configs/rec/rare/rare_resnet34.yaml", diff --git a/tools/infer/text/predict_from_yaml.py b/tools/infer/text/predict_from_yaml.py index b6d7ccb17..fff5cd0ca 100644 --- a/tools/infer/text/predict_from_yaml.py +++ b/tools/infer/text/predict_from_yaml.py @@ -2,7 +2,7 @@ Inference base on custom yaml Example: - $ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/db++_r50_icdar15.yaml + $ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml $ python tools/infer/text/predict_from_yaml.py --config configs/rec/crnn/crnn_resnet34.yaml """ import argparse From cbaea0644509982405c92511a909da11e921fd54 Mon Sep 17 00:00:00 2001 From: horcham <690936541@qq.com> Date: Wed, 31 Jan 2024 17:50:16 +0800 Subject: [PATCH 03/12] Fix problems of psenet-ctw1500 training --- README.md | 2 +- README_CN.md | 2 +- deploy/py_infer/src/utils/adapted/mindocr_models.py | 2 +- docs/cn/inference/inference_quickstart.md | 4 ++-- docs/en/inference/inference_quickstart.md | 4 ++-- tests/ut/test_models.py | 2 +- tools/infer/text/predict_from_yaml.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index a480ad008..10aa182ec 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ You may adapt it to your task/dataset, for example, by running ```shell # train text detection model DBNet++ on icdar15 dataset -python tools/train.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml +python tools/train.py --config configs/det/dbnet/dbpp_r50_icdar15.yaml ``` ```shell diff --git a/README_CN.md b/README_CN.md index a0b0b977c..23c30f741 100644 --- a/README_CN.md +++ b/README_CN.md @@ -108,7 +108,7 @@ MindOCR在`configs`文件夹中提供系列SoTA的OCR模型及其训练策略, ```shell # train text detection model DBNet++ on icdar15 dataset -python tools/train.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml +python tools/train.py --config configs/det/dbnet/dbpp_r50_icdar15.yaml ``` ```shell # train text recognition model CRNN on icdar15 dataset diff --git a/deploy/py_infer/src/utils/adapted/mindocr_models.py b/deploy/py_infer/src/utils/adapted/mindocr_models.py index b9c080862..2858d7435 100644 --- a/deploy/py_infer/src/utils/adapted/mindocr_models.py +++ b/deploy/py_infer/src/utils/adapted/mindocr_models.py @@ -4,7 +4,7 @@ MINDOCR_MODELS = { "en_ms_det_dbnet_resnet50": "det/dbnet/db_r50_icdar15.yaml", - "en_ms_det_dbnetpp_resnet50": "det/dbnet/dbnetpp_r50_icdar15.yaml", + "en_ms_det_dbnetpp_resnet50": "det/dbnet/dbpp_r50_icdar15.yaml", "en_ms_det_psenet_resnet152": "det/psenet/pse_r152_icdar15.yaml", "en_ms_det_psenet_resnet50": "det/psenet/pse_r50_icdar15.yaml", "en_ms_det_psenet_mobilenetv3": "det/psenet/pse_mv3_icdar15.yaml", diff --git a/docs/cn/inference/inference_quickstart.md b/docs/cn/inference/inference_quickstart.md index 18ee2c6a5..007232333 100644 --- a/docs/cn/inference/inference_quickstart.md +++ b/docs/cn/inference/inference_quickstart.md @@ -8,8 +8,8 @@ | | ResNet-18 | en | IC15 | 81.73 | 24.04 | (1,3,736,1280) | 
[yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) | | | ResNet-50 | en | IC15 | 85.00 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | | ResNet-50 | ch + en | 12个数据集 | 83.41 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) | -| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| | ResNet-50 | ch + en | 12个数据集 | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | +| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| | ResNet-50 | ch + en | 12个数据集 | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | | [EAST](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east) | ResNet-50 | en | IC15 | 86.86 | 6.72 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east/east_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359-5f05cd42.mindir) | | | MobileNetV3 | en | IC15 | 75.32 | 26.77 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/det/east/east_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1.ckpt) \| 
[mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1-5bf242c5.mindir) | | [PSENet](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet) | ResNet-152 | en | IC15 | 82.50 | 2.52 | (1,3,1472,2624) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet/pse_r152_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798-0d755205.mindir) | diff --git a/docs/en/inference/inference_quickstart.md b/docs/en/inference/inference_quickstart.md index f0d026f8d..6652a2df6 100644 --- a/docs/en/inference/inference_quickstart.md +++ b/docs/en/inference/inference_quickstart.md @@ -8,8 +8,8 @@ | | ResNet-18 | en | IC15 | 81.73 | 24.04 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r18_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet18-0c0c4cfa-cf46eb8b.mindir) | | | ResNet-50 | en | IC15 | 85.00 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-c3a4aa24-fbf95c82.mindir) | | | ResNet-50 | ch + en | 12 Datasets | 83.41 | 21.69 | (1,3,736,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/db_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_ch_en_general-a5dbb141-912f0a90.mindir) | -| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | -| | ResNet-50 | ch + en | 12 Datasets | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbnetpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | +| [DBNet++](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet) | ResNet-50 | en | IC15 | 86.79 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50-068166c2-9934aff0.mindir) | +| | ResNet-50 | ch + en | 12 Datasets | 84.30 | 8.46 | (1,3,1152,2048) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/dbnet/dbpp_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9.ckpt) \| 
[mindir](https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnetpp_resnet50_ch_en_general-884ba5b9-b3f52398.mindir) | | [EAST](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east) | ResNet-50 | en | IC15 | 86.86 | 6.72 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/east/east_r50_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_resnet50_ic15-7262e359-5f05cd42.mindir) | | | MobileNetV3 | en | IC15 | 75.32 | 26.77 | (1,3,720,1280) | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/det/east/east_mobilenetv3_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/east/east_mobilenetv3_ic15-4288dba1-5bf242c5.mindir) | | [PSENet](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet) | ResNet-152 | en | IC15 | 82.50 | 2.52 | (1,3,1472,2624) | [yaml](https://github.com/mindspore-lab/mindocr/tree/main/configs/det/psenet/pse_r152_icdar15.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/psenet/psenet_resnet152_ic15-6058a798-0d755205.mindir) | diff --git a/tests/ut/test_models.py b/tests/ut/test_models.py index 240874a7a..88a2fe576 100644 --- a/tests/ut/test_models.py +++ b/tests/ut/test_models.py @@ -15,7 +15,7 @@ all_yamls = [ "configs/det/dbnet/db_r50_icdar15.yaml", - "configs/det/dbnet/dbnetpp_r50_icdar15.yaml", + "configs/det/dbnet/dbpp_r50_icdar15.yaml", "configs/rec/crnn/crnn_resnet34.yaml", "configs/rec/master/master_resnet31.yaml", "configs/rec/rare/rare_resnet34.yaml", diff --git a/tools/infer/text/predict_from_yaml.py b/tools/infer/text/predict_from_yaml.py index fff5cd0ca..03706cb77 100644 --- a/tools/infer/text/predict_from_yaml.py +++ b/tools/infer/text/predict_from_yaml.py @@ -2,7 +2,7 @@ Inference base on custom yaml Example: - $ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/dbnetpp_r50_icdar15.yaml + $ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/dbpp_r50_icdar15.yaml $ python tools/infer/text/predict_from_yaml.py --config configs/rec/crnn/crnn_resnet34.yaml """ import argparse From e38bacb9acdcec709d1f8ac0fd87ab706e17ee93 Mon Sep 17 00:00:00 2001 From: horcham <690936541@qq.com> Date: Mon, 18 Mar 2024 15:28:26 +0800 Subject: [PATCH 04/12] for export --- deploy/py_infer/src/core/model/model.py | 4 +-- mindocr/losses/det_loss.py | 11 +++++- mindocr/losses/rec_loss.py | 15 ++++++-- mindocr/models/necks/fpn.py | 34 +++++++++++++------ .../transforms/tps_spatial_transformer.py | 14 ++++++-- mindocr/models/utils/attention_cells.py | 14 ++++++-- 6 files changed, 72 insertions(+), 20 deletions(-) diff --git a/deploy/py_infer/src/core/model/model.py b/deploy/py_infer/src/core/model/model.py index 9a809533f..fa23d8da1 100644 --- a/deploy/py_infer/src/core/model/model.py +++ b/deploy/py_infer/src/core/model/model.py @@ -106,8 +106,8 @@ def warmup(self): height, width = hw_list[0] warmup_shape = [(*other_shape, height, width)] # Only single input - dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)] - self.model.infer(dummy_tensor) + # dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, 
self.input_dtype)] + # self.model.infer(dummy_tensor) def __del__(self): if hasattr(self, "model") and self.model: diff --git a/mindocr/losses/det_loss.py b/mindocr/losses/det_loss.py index 23ca8f4e2..cc97a3210 100644 --- a/mindocr/losses/det_loss.py +++ b/mindocr/losses/det_loss.py @@ -1,4 +1,5 @@ import logging +import os from math import pi from typing import Tuple, Union @@ -10,6 +11,8 @@ __all__ = ["DBLoss", "PSEDiceLoss", "EASTLoss", "FCELoss"] _logger = logging.getLogger(__name__) +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + class DBLoss(nn.LossBase): """ @@ -165,7 +168,13 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor: neg_loss = (loss * negative).view(loss.shape[0], -1) neg_vals, _ = ops.sort(neg_loss) - neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1) + + if OFFLINE_MODE is None: + neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1) + else: + neg_index = ops.stack( + (ops.arange(loss.shape[0], dtype=neg_count.dtype), neg_vals.shape[1] - neg_count), axis=1 + ) min_neg_score = ops.expand_dims(ops.gather_nd(neg_vals, neg_index), axis=1) neg_loss_mask = (neg_loss >= min_neg_score).astype(ms.float32) # filter values less than top k diff --git a/mindocr/losses/rec_loss.py b/mindocr/losses/rec_loss.py index 09ee8caec..88cb4e7f5 100644 --- a/mindocr/losses/rec_loss.py +++ b/mindocr/losses/rec_loss.py @@ -1,3 +1,5 @@ +import os + import numpy as np import mindspore as ms @@ -6,6 +8,8 @@ __all__ = ["CTCLoss", "AttentionLoss", "VisionLANLoss"] +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + class CTCLoss(LossBase): """ @@ -147,14 +151,21 @@ class AttentionLoss(LossBase): def __init__(self, reduction: str = "mean", ignore_index: int = 0) -> None: super().__init__() # ignore symbol, assume it is placed at 0th index - self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index) + if OFFLINE_MODE is None: + self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index) + else: + self.reduction = reduction + self.ignore_index = ignore_index def construct(self, logits: Tensor, labels: Tensor) -> Tensor: labels = labels[:, 1:] # without symbol num_classes = logits.shape[-1] logits = ops.reshape(logits, (-1, num_classes)) labels = ops.reshape(labels, (-1,)) - return self.criterion(logits, labels) + if OFFLINE_MODE is None: + return self.criterion(logits, labels) + else: + return ops.cross_entropy(logits, labels, reduction=self.reduction, ignore_index=self.ignore_index) class SARLoss(LossBase): diff --git a/mindocr/models/necks/fpn.py b/mindocr/models/necks/fpn.py index 650395554..32a628ce8 100644 --- a/mindocr/models/necks/fpn.py +++ b/mindocr/models/necks/fpn.py @@ -1,3 +1,4 @@ +import os from typing import List, Tuple from mindspore import Tensor, nn, ops @@ -7,14 +8,20 @@ from ..utils.attention_cells import SEModule from .asf import AdaptiveScaleFusion +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) -def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None): - if scale == 1 or shape == x.shape[2:]: - return x - if scale: - shape = (x.shape[2] * scale, x.shape[3] * scale) - return ops.ResizeNearestNeighbor(shape)(x) +if OFFLINE_MODE is None: + def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None): + if scale == 1 or shape == x.shape[2:]: + return x + + if scale: + shape = (x.shape[2] * scale, x.shape[3] * scale) + return ops.ResizeNearestNeighbor(shape)(x) +else: + def _resize_nn(x: Tensor, shape: Tensor): + return 
ops.ResizeNearestNeighborV2()(x, shape) class FPN(nn.Cell): @@ -64,11 +71,18 @@ def construct(self, features: List[Tensor]) -> Tensor: for i, uc_op in enumerate(self.unify_channels): features[i] = uc_op(features[i]) - for i in range(2, -1, -1): - features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:]) + if OFFLINE_MODE is None: + for i in range(2, -1, -1): + features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:]) + + for i, out in enumerate(self.out): + features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:]) + else: + for i in range(2, -1, -1): + features[i] += _resize_nn(features[i + 1], shape=ops.dyn_shape(features[i])[2:]) - for i, out in enumerate(self.out): - features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:]) + for i, out in enumerate(self.out): + features[i] = _resize_nn(out(features[i]), shape=ops.dyn_shape(features[0])[2:]) return self.fuse(features[::-1]) # matching the reverse order of the original work diff --git a/mindocr/models/transforms/tps_spatial_transformer.py b/mindocr/models/transforms/tps_spatial_transformer.py index a49736fe3..95d3f6d75 100644 --- a/mindocr/models/transforms/tps_spatial_transformer.py +++ b/mindocr/models/transforms/tps_spatial_transformer.py @@ -1,4 +1,5 @@ import itertools +import os from typing import Optional, Tuple import numpy as np @@ -8,6 +9,8 @@ import mindspore.ops as ops from mindspore import Tensor +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor: output = ops.grid_sample(input, grid) @@ -111,6 +114,9 @@ def __init__( self.target_coordinate_repr = Tensor(target_coordinate_repr, dtype=ms.float32) self.target_control_points = Tensor(target_control_points, dtype=ms.float32) + if OFFLINE_MODE is not None: + self.matmul = ops.BatchMatMul() + def construct( self, input: Tensor, source_control_points: Tensor ) -> Tuple[Tensor, Tensor]: @@ -118,8 +124,12 @@ def construct( padding_matrix = ops.tile(self.padding_matrix, (batch_size, 1, 1)) Y = ops.concat([source_control_points, padding_matrix], axis=1) - mapping_matrix = ops.matmul(self.inverse_kernel, Y) - source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix) + if OFFLINE_MODE is None: + mapping_matrix = ops.matmul(self.inverse_kernel, Y) + source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix) + else: + mapping_matrix = self.matmul(self.inverse_kernel[None, ...], Y) + source_coordinate = self.matmul(self.target_coordinate_repr[None, ...], mapping_matrix) grid = ops.reshape( source_coordinate, (-1, self.target_height, self.target_width, 2), diff --git a/mindocr/models/utils/attention_cells.py b/mindocr/models/utils/attention_cells.py index 016001085..b3f14dc24 100644 --- a/mindocr/models/utils/attention_cells.py +++ b/mindocr/models/utils/attention_cells.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Tuple import numpy as np @@ -9,6 +10,8 @@ __all__ = ["MultiHeadAttention", "PositionwiseFeedForward", "PositionalEncoding", "SEModule"] +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + class MultiHeadAttention(nn.Cell): def __init__( @@ -108,9 +111,14 @@ def __init__( self.pe = Tensor(pe, dtype=ms.float32) def construct(self, input_tensor: Tensor) -> Tensor: - input_tensor = ( - input_tensor + self.pe[:, : input_tensor.shape[1]] - ) # pe 1 5000 512 + if OFFLINE_MODE is None: + input_tensor = ( + input_tensor + self.pe[:, : input_tensor.shape[1]] + ) # pe 1 5000 512 + else: + input_tensor = ( 
+ input_tensor + self.pe[:, : ops.dyn_shape(input_tensor)[1]] + ) # pe 1 5000 512 return self.dropout(input_tensor) From 77d8138b52a83d7e6de00202a01ffa94d92c2ac8 Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 18 Mar 2024 19:38:08 +0800 Subject: [PATCH 05/12] Add the function of concatenating to crops after detection. --- deploy/py_infer/src/infer_args.py | 3 +++ .../module/detection/det_post_node.py | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py index fc7285939..fbc55db16 100644 --- a/deploy/py_infer/src/infer_args.py +++ b/deploy/py_infer/src/infer_args.py @@ -119,6 +119,9 @@ def get_args(): "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." ) parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.") + parser.add_argument( + "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection." + ) args = parser.parse_args() setup_logger(args) diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py index fba6a5abc..47966dc40 100644 --- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py +++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py @@ -1,3 +1,4 @@ +import cv2 import numpy as np from ....data_process.utils import cv_utils @@ -16,6 +17,28 @@ def init_self_args(self): self.text_detector.init(preprocess=False, model=False, postprocess=True) super().init_self_args() + def concat_crops(self, crops: list): + """ + Concatenates the list of cropped images horizontally after resizing them to have the same height. + + Args: + crops (list): A list of cropped images represented as numpy arrays. + + Returns: + numpy.ndarray: A horizontally concatenated image array. 
+ """ + max_height = max(crop.shape[0] for crop in crops) + resized_crops = [] + for crop in crops: + h, w, c = crop.shape + new_h = max_height + new_w = int((w / h) * new_h) + + resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + resized_crops.append(resized_img) + crops = np.concatenate(resized_crops, axis=1) + return crops + def process(self, input_data): if input_data.skip: self.send_to_next_module(input_data) @@ -39,6 +62,8 @@ def process(self, input_data): for box in infer_res_list: sub_image = cv_utils.crop_box_from_image(image, np.array(box)) sub_image_list.append(sub_image) + if self.is_concat: + sub_image_list = [self.concat_crops(sub_image_list)] input_data.sub_image_list = sub_image_list input_data.data = None From 8c1938b3450e65b0287a04740f5be69308707918 Mon Sep 17 00:00:00 2001 From: horcham <690936541@qq.com> Date: Mon, 18 Mar 2024 20:05:20 +0800 Subject: [PATCH 06/12] fix large npu memory cost --- deploy/py_infer/src/data_process/postprocess/builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deploy/py_infer/src/data_process/postprocess/builder.py b/deploy/py_infer/src/data_process/postprocess/builder.py index 092f415af..ddb7892a0 100644 --- a/deploy/py_infer/src/data_process/postprocess/builder.py +++ b/deploy/py_infer/src/data_process/postprocess/builder.py @@ -44,6 +44,7 @@ def get_device_status(): def _get_status(): nonlocal status try: + ms.set_context(max_device_memory="0.01GB") status = ms.Tensor([0])[0:].asnumpy()[0] except RuntimeError: status = 1 From b20f0faa43650e2515e18a44d6f20897782a2c8a Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 18 Mar 2024 19:38:08 +0800 Subject: [PATCH 07/12] Add the function of concatenating to crops after detection. --- deploy/py_infer/src/infer_args.py | 3 +++ .../module/detection/det_post_node.py | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py index fc7285939..fbc55db16 100644 --- a/deploy/py_infer/src/infer_args.py +++ b/deploy/py_infer/src/infer_args.py @@ -119,6 +119,9 @@ def get_args(): "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." ) parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.") + parser.add_argument( + "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection." + ) args = parser.parse_args() setup_logger(args) diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py index fba6a5abc..c91ec876a 100644 --- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py +++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py @@ -1,3 +1,4 @@ +import cv2 import numpy as np from ....data_process.utils import cv_utils @@ -10,12 +11,35 @@ def __init__(self, args, msg_queue): super(DetPostNode, self).__init__(args, msg_queue) self.text_detector = None self.task_type = self.args.task_type + self.is_concat = self.args.is_concat def init_self_args(self): self.text_detector = TextDetector(self.args) self.text_detector.init(preprocess=False, model=False, postprocess=True) super().init_self_args() + def concat_crops(self, crops: list): + """ + Concatenates the list of cropped images horizontally after resizing them to have the same height. + + Args: + crops (list): A list of cropped images represented as numpy arrays. 
+ + Returns: + numpy.ndarray: A horizontally concatenated image array. + """ + max_height = max(crop.shape[0] for crop in crops) + resized_crops = [] + for crop in crops: + h, w, c = crop.shape + new_h = max_height + new_w = int((w / h) * new_h) + + resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + resized_crops.append(resized_img) + crops = np.concatenate(resized_crops, axis=1) + return crops + def process(self, input_data): if input_data.skip: self.send_to_next_module(input_data) @@ -39,6 +63,8 @@ def process(self, input_data): for box in infer_res_list: sub_image = cv_utils.crop_box_from_image(image, np.array(box)) sub_image_list.append(sub_image) + if self.is_concat: + sub_image_list = [self.concat_crops(sub_image_list)] input_data.sub_image_list = sub_image_list input_data.data = None From 56dab202f133c842828dbd57f4615de0f57f03db Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 18 Mar 2024 19:38:08 +0800 Subject: [PATCH 08/12] Add the function of concatenating to crops after detection. --- deploy/py_infer/src/infer_args.py | 3 +++ .../module/detection/det_post_node.py | 27 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py index fc7285939..fbc55db16 100644 --- a/deploy/py_infer/src/infer_args.py +++ b/deploy/py_infer/src/infer_args.py @@ -119,6 +119,9 @@ def get_args(): "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." ) parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.") + parser.add_argument( + "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection." + ) args = parser.parse_args() setup_logger(args) diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py index fba6a5abc..aed886e44 100644 --- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py +++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py @@ -1,3 +1,4 @@ +import cv2 import numpy as np from ....data_process.utils import cv_utils @@ -10,12 +11,36 @@ def __init__(self, args, msg_queue): super(DetPostNode, self).__init__(args, msg_queue) self.text_detector = None self.task_type = self.args.task_type + self.is_concat = self.args.is_concat def init_self_args(self): self.text_detector = TextDetector(self.args) self.text_detector.init(preprocess=False, model=False, postprocess=True) super().init_self_args() + def concat_crops(self, crops: list): + """ + Concatenates the list of cropped images horizontally after resizing them to have the same height. + + Args: + crops (list): A list of cropped images represented as numpy arrays. + + Returns: + numpy.ndarray: A horizontally concatenated image array. 
+ """ + crops_sorted = sorted(crops, key=lambda points: (points[0][1], points[0][0])) + max_height = max(crop.shape[0] for crop in crops_sorted) + resized_crops = [] + for crop in crops_sorted: + h, w, c = crop.shape + new_h = max_height + new_w = int((w / h) * new_h) + + resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + resized_crops.append(resized_img) + crops_concated = np.concatenate(resized_crops, axis=1) + return crops_concated + def process(self, input_data): if input_data.skip: self.send_to_next_module(input_data) @@ -39,6 +64,8 @@ def process(self, input_data): for box in infer_res_list: sub_image = cv_utils.crop_box_from_image(image, np.array(box)) sub_image_list.append(sub_image) + if self.is_concat: + sub_image_list = [self.concat_crops(sub_image_list)] input_data.sub_image_list = sub_image_list input_data.data = None From 855badeadd70056e75ebf1c03552b4250efb158f Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 18 Mar 2024 19:38:08 +0800 Subject: [PATCH 09/12] Add the function of concatenating to crops after detection. --- deploy/py_infer/src/infer_args.py | 3 ++ .../module/detection/det_post_node.py | 28 +++++++++++++++++++ .../module/recognition/rec_post_node.py | 11 ++++++-- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py index fc7285939..fbc55db16 100644 --- a/deploy/py_infer/src/infer_args.py +++ b/deploy/py_infer/src/infer_args.py @@ -119,6 +119,9 @@ def get_args(): "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." ) parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.") + parser.add_argument( + "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection." + ) args = parser.parse_args() setup_logger(args) diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py index fba6a5abc..c7f87ef83 100644 --- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py +++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py @@ -1,3 +1,4 @@ +import cv2 import numpy as np from ....data_process.utils import cv_utils @@ -10,12 +11,35 @@ def __init__(self, args, msg_queue): super(DetPostNode, self).__init__(args, msg_queue) self.text_detector = None self.task_type = self.args.task_type + self.is_concat = self.args.is_concat def init_self_args(self): self.text_detector = TextDetector(self.args) self.text_detector.init(preprocess=False, model=False, postprocess=True) super().init_self_args() + def concat_crops(self, crops: list): + """ + Concatenates the list of cropped images horizontally after resizing them to have the same height. + + Args: + crops (list): A list of cropped images represented as numpy arrays. + + Returns: + numpy.ndarray: A horizontally concatenated image array. 
+ """ + max_height = max(crop.shape[0] for crop in crops) + resized_crops = [] + for crop in crops: + h, w, c = crop.shape + new_h = max_height + new_w = int((w / h) * new_h) + + resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + resized_crops.append(resized_img) + crops_concated = np.concatenate(resized_crops, axis=1) + return crops_concated + def process(self, input_data): if input_data.skip: self.send_to_next_module(input_data) @@ -23,6 +47,8 @@ def process(self, input_data): data = input_data.data boxes = self.text_detector.postprocess(data["pred"], data["shape_list"]) + if self.is_concat: + boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0])) infer_res_list = [] for box in boxes: @@ -39,6 +65,8 @@ def process(self, input_data): for box in infer_res_list: sub_image = cv_utils.crop_box_from_image(image, np.array(box)) sub_image_list.append(sub_image) + if self.is_concat: + sub_image_list = [self.concat_crops(sub_image_list)] input_data.sub_image_list = sub_image_list input_data.data = None diff --git a/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py b/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py index 918cc56ba..5332a579c 100644 --- a/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py +++ b/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py @@ -7,6 +7,7 @@ def __init__(self, args, msg_queue): super(RecPostNode, self).__init__(args, msg_queue) self.text_recognizer = None self.task_type = self.args.task_type + self.is_concat = self.args.is_concat def init_self_args(self): self.text_recognizer = TextRecognizer(self.args) @@ -28,9 +29,13 @@ def process(self, input_data): else: texts = output["texts"] confs = output["confs"] - for result, text, conf in zip(input_data.infer_result, texts, confs): - result.append(text) - result.append(conf) + for i, result in enumerate(input_data.infer_result): + if self.is_concat: + result.append(texts[0]) + result.append(confs[0]) + else: + result.append(texts[i]) + result.append(confs[i]) input_data.data = None From 27505814b758757915db86b4b732e2975e046f7c Mon Sep 17 00:00:00 2001 From: Bourn3z Date: Mon, 18 Mar 2024 19:38:08 +0800 Subject: [PATCH 10/12] Add the function of concatenating to crops after detection. --- deploy/py_infer/src/infer_args.py | 3 ++ .../module/detection/det_post_node.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py index fc7285939..fbc55db16 100644 --- a/deploy/py_infer/src/infer_args.py +++ b/deploy/py_infer/src/infer_args.py @@ -119,6 +119,9 @@ def get_args(): "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." ) parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.") + parser.add_argument( + "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection." 
+    )
 
     args = parser.parse_args()
     setup_logger(args)
diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
index fba6a5abc..18d4ce5e0 100644
--- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py
+++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
@@ -1,3 +1,4 @@
+import cv2
 import numpy as np
 
 from ....data_process.utils import cv_utils
@@ -10,12 +11,35 @@ def __init__(self, args, msg_queue):
         super(DetPostNode, self).__init__(args, msg_queue)
         self.text_detector = None
         self.task_type = self.args.task_type
+        self.is_concat = self.args.is_concat
 
     def init_self_args(self):
         self.text_detector = TextDetector(self.args)
         self.text_detector.init(preprocess=False, model=False, postprocess=True)
         super().init_self_args()
 
+    def concat_crops(self, crops: list):
+        """
+        Concatenates the list of cropped images horizontally after resizing them to have the same height.
+
+        Args:
+            crops (list): A list of cropped images represented as numpy arrays.
+
+        Returns:
+            numpy.ndarray: A horizontally concatenated image array.
+        """
+        max_height = max(crop.shape[0] for crop in crops)
+        resized_crops = []
+        for crop in crops:
+            h, w, c = crop.shape
+            new_h = max_height
+            new_w = int((w / h) * new_h)
+
+            resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+            resized_crops.append(resized_img)
+        crops_concated = np.concatenate(resized_crops, axis=1)
+        return crops_concated
+
     def process(self, input_data):
         if input_data.skip:
             self.send_to_next_module(input_data)
@@ -23,6 +47,8 @@ def process(self, input_data):
 
         data = input_data.data
         boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
+        if self.is_concat:
+            boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
 
         infer_res_list = []
         for box in boxes:
@@ -39,6 +65,8 @@ def process(self, input_data):
             for box in infer_res_list:
                 sub_image = cv_utils.crop_box_from_image(image, np.array(box))
                 sub_image_list.append(sub_image)
+            if self.is_concat:
+                sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
             input_data.sub_image_list = sub_image_list
 
         input_data.data = None

From d654c1309239697207151886fa695ce6a1c15f5a Mon Sep 17 00:00:00 2001
From: Bourn3z
Date: Mon, 18 Mar 2024 19:38:08 +0800
Subject: [PATCH 11/12] Add the function of concatenating crops after detection.

---
 deploy/py_infer/src/infer_args.py | 3 ++
 .../module/detection/det_post_node.py | 28 +++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py
index fc7285939..fbc55db16 100644
--- a/deploy/py_infer/src/infer_args.py
+++ b/deploy/py_infer/src/infer_args.py
@@ -119,6 +119,9 @@ def get_args():
         "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
     )
     parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
+    parser.add_argument(
+        "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
+    )
 
     args = parser.parse_args()
     setup_logger(args)
diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
index fba6a5abc..18d4ce5e0 100644
--- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py
+++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
@@ -1,3 +1,4 @@
+import cv2
 import numpy as np
 
 from ....data_process.utils import cv_utils
@@ -10,12 +11,35 @@ def __init__(self, args, msg_queue):
         super(DetPostNode, self).__init__(args, msg_queue)
         self.text_detector = None
         self.task_type = self.args.task_type
+        self.is_concat = self.args.is_concat
 
     def init_self_args(self):
         self.text_detector = TextDetector(self.args)
         self.text_detector.init(preprocess=False, model=False, postprocess=True)
         super().init_self_args()
 
+    def concat_crops(self, crops: list):
+        """
+        Concatenates the list of cropped images horizontally after resizing them to have the same height.
+
+        Args:
+            crops (list): A list of cropped images represented as numpy arrays.
+
+        Returns:
+            numpy.ndarray: A horizontally concatenated image array.
+        """
+        max_height = max(crop.shape[0] for crop in crops)
+        resized_crops = []
+        for crop in crops:
+            h, w, c = crop.shape
+            new_h = max_height
+            new_w = int((w / h) * new_h)
+
+            resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+            resized_crops.append(resized_img)
+        crops_concated = np.concatenate(resized_crops, axis=1)
+        return crops_concated
+
     def process(self, input_data):
         if input_data.skip:
             self.send_to_next_module(input_data)
@@ -23,6 +47,8 @@ def process(self, input_data):
 
         data = input_data.data
         boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
+        if self.is_concat:
+            boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
 
         infer_res_list = []
         for box in boxes:
@@ -39,6 +65,8 @@ def process(self, input_data):
             for box in infer_res_list:
                 sub_image = cv_utils.crop_box_from_image(image, np.array(box))
                 sub_image_list.append(sub_image)
+            if self.is_concat:
+                sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
             input_data.sub_image_list = sub_image_list
 
         input_data.data = None

From 258da172e592e5c335f15134d8de724cbeeca6cc Mon Sep 17 00:00:00 2001
From: Bourn3z
Date: Mon, 18 Mar 2024 19:38:08 +0800
Subject: [PATCH 12/12] Add the function of concatenating crops after detection.

---
 deploy/py_infer/src/infer_args.py | 3 ++
 .../module/detection/det_post_node.py | 28 +++++++++++++++++++
 .../module/recognition/rec_post_node.py | 6 ++--
 3 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py
index fc7285939..fbc55db16 100644
--- a/deploy/py_infer/src/infer_args.py
+++ b/deploy/py_infer/src/infer_args.py
@@ -119,6 +119,9 @@ def get_args():
         "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
     )
     parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
+    parser.add_argument(
+        "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
+    )
 
     args = parser.parse_args()
     setup_logger(args)
diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
index fba6a5abc..18d4ce5e0 100644
--- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py
+++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
@@ -1,3 +1,4 @@
+import cv2
 import numpy as np
 
 from ....data_process.utils import cv_utils
@@ -10,12 +11,35 @@ def __init__(self, args, msg_queue):
         super(DetPostNode, self).__init__(args, msg_queue)
         self.text_detector = None
         self.task_type = self.args.task_type
+        self.is_concat = self.args.is_concat
 
     def init_self_args(self):
         self.text_detector = TextDetector(self.args)
         self.text_detector.init(preprocess=False, model=False, postprocess=True)
         super().init_self_args()
 
+    def concat_crops(self, crops: list):
+        """
+        Concatenates the list of cropped images horizontally after resizing them to have the same height.
+
+        Args:
+            crops (list): A list of cropped images represented as numpy arrays.
+
+        Returns:
+            numpy.ndarray: A horizontally concatenated image array.
+        """
+        max_height = max(crop.shape[0] for crop in crops)
+        resized_crops = []
+        for crop in crops:
+            h, w, c = crop.shape
+            new_h = max_height
+            new_w = int((w / h) * new_h)
+
+            resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+            resized_crops.append(resized_img)
+        crops_concated = np.concatenate(resized_crops, axis=1)
+        return crops_concated
+
     def process(self, input_data):
         if input_data.skip:
             self.send_to_next_module(input_data)
@@ -23,6 +47,8 @@ def process(self, input_data):
 
         data = input_data.data
         boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
+        if self.is_concat:
+            boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
 
         infer_res_list = []
         for box in boxes:
@@ -39,6 +65,8 @@ def process(self, input_data):
             for box in infer_res_list:
                 sub_image = cv_utils.crop_box_from_image(image, np.array(box))
                 sub_image_list.append(sub_image)
+            if self.is_concat:
+                sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
             input_data.sub_image_list = sub_image_list
 
         input_data.data = None
diff --git a/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py b/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py
index 918cc56ba..e07bfb3cd 100644
--- a/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py
+++ b/deploy/py_infer/src/parallel/module/recognition/rec_post_node.py
@@ -28,9 +28,9 @@ def process(self, input_data):
         else:
             texts = output["texts"]
            confs = output["confs"]
-            for result, text, conf in zip(input_data.infer_result, texts, confs):
-                result.append(text)
-                result.append(conf)
+            for results, text, conf in zip(input_data.infer_result, texts, confs):
+                results.append(text)
+                results.append(conf)
 
         input_data.data = None
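
For readers following the series, below is a minimal, self-contained sketch of the crop-concatenation behaviour that the new `--is_concat` flag switches on: detected text crops are resized to a common height while keeping each crop's aspect ratio, then concatenated left to right so that recognition can run on a single merged strip. The helper name `concat_crops_horizontally` and the dummy arrays are illustrative only and are not part of the repository.

```python
# Standalone sketch of the crop-concatenation step added in these patches.
# Assumes OpenCV and NumPy are installed; names and data here are illustrative.
import cv2
import numpy as np


def concat_crops_horizontally(crops):
    """Resize crops to a common height, then concatenate them left to right."""
    max_height = max(crop.shape[0] for crop in crops)
    resized = []
    for crop in crops:
        h, w = crop.shape[:2]
        # Keep the aspect ratio of each crop while scaling it to max_height.
        new_w = int((w / h) * max_height)
        resized.append(cv2.resize(crop, (new_w, max_height), interpolation=cv2.INTER_LINEAR))
    # All crops now share the same height, so they can be joined along the width axis.
    return np.concatenate(resized, axis=1)


if __name__ == "__main__":
    # Two dummy BGR crops of different sizes stand in for detected text regions.
    crops = [
        np.zeros((32, 100, 3), dtype=np.uint8),
        np.full((48, 80, 3), 255, dtype=np.uint8),
    ]
    merged = concat_crops_horizontally(crops)
    print(merged.shape)  # (48, 230, 3): the first crop is scaled to 150x48, the second stays 80x48
```

Sorting the boxes by their top-left corner before cropping, as the patches do when `is_concat` is enabled, keeps the merged strip roughly in reading order; the single recognition result for that strip is then attached to every detected box in the recognition post-processing node.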