Skip to content

Commit 80ee29b

Browse files
authored
[cherry-pick] refine table related method (#3601)
1 parent 5e647ef commit 80ee29b

File tree

6 files changed

+81
-16
lines changed

6 files changed

+81
-16
lines changed

docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.en.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,13 @@ In the above Python script, the following steps are executed:
893893
</td>
894894
<td><code>None</code></td>
895895
</tr>
896+
<tr>
<td><code>use_table_cells_ocr_results</code></td>
897+
<td>Whether to enable the table-cells OCR mode. When disabled, the global OCR result is used to fill the HTML table; when enabled, OCR is performed cell by cell and the results are filled into the HTML table. The two modes perform differently in different scenarios, so choose according to the actual situation.</td>
898+
<td><code>bool|False</code></td>
899+
<td>
900+
<ul>
901+
<li><b>bool</b>:<code>True</code> or <code>False</code></li></ul></td>
902+
<td><code>False</code></td>
896903
</table>
897904

898905
(3) Process the prediction results, where each sample's prediction result is represented as a corresponding Result object, and supports operations such as printing, saving as an image, saving as an `xlsx` file, saving as an `HTML` file, and saving as a `json` file:

docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,14 @@ for res in output:
895895
<li><b>float</b>:大于 <code>0</code> 的任意浮点数
896896
<li><b>None</b>:如果设置为 <code>None</code>, 将默认使用产线初始化的该参数值 <code>0.0</code>。即不设阈值</li></li></ul></td>
897897
<td><code>None</code></td>
898+
</tr>
899+
<tr>
<td><code>use_table_cells_ocr_results</code></td>
900+
<td>是否启用单元格OCR模式,不启用时采用全局OCR结果填充至html表格,启用时逐个单元格做OCR并填充至html表格。二者在不同场景下表现不同,请根据实际情况选择。</td>
901+
<td><code>bool|False</code></td>
902+
<td>
903+
<ul>
904+
<li><b>bool</b>:<code>True</code> 或者 <code>False</code></li></ul></td>
905+
<td><code>False</code></td>
898906

899907
</tr></table>
900908

paddlex/inference/pipelines/table_recognition/pipeline_v2.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os, sys
1616
from typing import Any, Dict, Optional, Union, List, Tuple
1717
import numpy as np
18+
import math
1819
import cv2
1920
from sklearn.cluster import KMeans
2021
from ..base import BasePipeline
@@ -497,12 +498,40 @@ def combine_rectangles(rectangles, N):
497498
if len(final_results) <= 0.6*html_pred_boxes_nums:
498499
final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
499500
return final_results
501+
502+
def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
503+
"""
504+
Splits OCR bounding boxes by table cells and retrieves text.
505+
506+
Args:
507+
ori_img (ndarray): The original image from which text regions will be extracted.
508+
cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
509+
510+
Returns:
511+
list: A list containing the recognized texts from each cell.
512+
"""
513+
514+
# Check if cells_bboxes is a list and convert it if not.
515+
if not isinstance(cells_bboxes, list):
516+
cells_bboxes = cells_bboxes.tolist()
517+
texts_list = [] # Initialize a list to store the recognized texts.
518+
# Process each bounding box provided in cells_bboxes.
519+
for i in range(len(cells_bboxes)):
520+
# Extract and round up the coordinates of the bounding box.
521+
x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
522+
# Perform OCR on the defined region of the image and get the recognized text.
523+
rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
524+
# Concatenate the texts and append them to the texts_list.
525+
texts_list.append(''.join(rec_te["rec_texts"]))
526+
# Return the list of recognized texts from each cell.
527+
return texts_list
500528

501529
def predict_single_table_recognition_res(
502530
self,
503531
image_array: np.ndarray,
504532
overall_ocr_res: OCRResult,
505533
table_box: list,
534+
use_table_cells_ocr_results: bool = False,
506535
flag_find_nei_text: bool = True,
507536
) -> SingleTableRecognitionResult:
508537
"""
@@ -517,6 +546,7 @@ def predict_single_table_recognition_res(
517546
Returns:
518547
SingleTableRecognitionResult: single table recognition result.
519548
"""
549+
520550
table_cls_pred = next(self.table_cls_model(image_array))
521551
table_cls_result = self.extract_results(table_cls_pred, "cls")
522552
if table_cls_result == "wired_table":
@@ -538,8 +568,12 @@ def predict_single_table_recognition_res(
538568
table_cells_result = self.cells_det_results_reprocessing(
539569
table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
540570
)
571+
if use_table_cells_ocr_results == True:
572+
cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
573+
else:
574+
cells_texts_list = []
541575
single_table_recognition_res = get_table_recognition_res(
542-
table_box, table_structure_result, table_cells_result, overall_ocr_res
576+
table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
543577
)
544578
neighbor_text = ""
545579
if flag_find_nei_text:
@@ -567,6 +601,7 @@ def predict(
567601
text_det_box_thresh: Optional[float] = None,
568602
text_det_unclip_ratio: Optional[float] = None,
569603
text_rec_score_thresh: Optional[float] = None,
604+
use_table_cells_ocr_results: Optional[bool] = False,
570605
**kwargs,
571606
) -> TableRecognitionResult:
572607
"""
@@ -638,6 +673,7 @@ def predict(
638673
doc_preprocessor_image,
639674
overall_ocr_res,
640675
table_box,
676+
use_table_cells_ocr_results,
641677
flag_find_nei_text=False,
642678
)
643679
single_table_rec_res["table_region_id"] = table_region_id
@@ -654,7 +690,7 @@ def predict(
654690
table_box = crop_img_info["box"]
655691
single_table_rec_res = (
656692
self.predict_single_table_recognition_res(
657-
crop_img_info["img"], overall_ocr_res, table_box
693+
crop_img_info["img"], overall_ocr_res, table_box, use_table_cells_ocr_results
658694
)
659695
)
660696
single_table_rec_res["table_region_id"] = table_region_id

paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,8 @@ def get_table_recognition_res(
403403
table_structure_result: list,
404404
table_cells_result: list,
405405
overall_ocr_res: OCRResult,
406+
cells_texts_list: list,
407+
use_table_cells_ocr_results: bool,
406408
) -> SingleTableRecognitionResult:
407409
"""
408410
Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
@@ -412,6 +414,8 @@ def get_table_recognition_res(
412414
table_structure_result (list): Predicted table structure.
413415
table_cells_result (list): Predicted table cells.
414416
overall_ocr_res (OCRResult): Overall OCR result from the input image.
417+
cells_texts_list (list): OCR results with cells.
418+
use_table_cells_ocr_results (bool): whether to use OCR results with cells.
415419
416420
Returns:
417421
SingleTableRecognitionResult: An object containing the single table recognition result.
@@ -425,12 +429,29 @@ def get_table_recognition_res(
425429
crop_start_point = [table_box[0][0], table_box[0][1]]
426430
img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
427431

432+
if len(table_cells_result) == 0 or len(table_ocr_pred["rec_boxes"]) == 0:
433+
pred_html = ' '.join(table_structure_result)
434+
if len(table_cells_result) != 0:
435+
table_cells_result = convert_table_structure_pred_bbox(
436+
table_cells_result, crop_start_point, img_shape
437+
)
438+
single_img_res = {
439+
"cell_box_list": table_cells_result,
440+
"table_ocr_pred": table_ocr_pred,
441+
"pred_html": pred_html,
442+
}
443+
return SingleTableRecognitionResult(single_img_res)
444+
428445
table_cells_result = convert_table_structure_pred_bbox(
429446
table_cells_result, crop_start_point, img_shape
430447
)
431448

432-
ocr_dt_boxes = table_ocr_pred["rec_boxes"]
433-
ocr_texts_res = table_ocr_pred["rec_texts"]
449+
if use_table_cells_ocr_results == False:
450+
ocr_dt_boxes = table_ocr_pred["rec_boxes"]
451+
ocr_texts_res = table_ocr_pred["rec_texts"]
452+
else:
453+
ocr_dt_boxes = table_cells_result
454+
ocr_texts_res = cells_texts_list
434455

435456
table_cells_result, table_cells_flag = sort_table_cells_boxes(table_cells_result)
436457
row_start_index = find_row_start_index(table_structure_result)

paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,6 @@ def check(dataset_dir, output, dataset_type="PubTabTableRecDataset", sample_num=
7575
)
7676
sample_paths[tag].append(sample_path)
7777

78-
boxes_num = len(cells)
79-
tokens_num = sum(
80-
[
81-
structure.count(x)
82-
for x in ["<td>", "<td", "<eb></eb>", "<td></td>"]
83-
]
84-
)
85-
if boxes_num != tokens_num:
86-
raise CheckFailedError(
87-
f"The number of cells needs to be consistent with the number of tokens "
88-
"but the number of cells is {boxes_num}, and the number of tokens is {tokens_num}."
89-
)
9078
meta = {}
9179

9280
meta["train_samples"] = sample_cnts["train"]

paddlex/utils/pipeline_arguments.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def validator(cli_input: str) -> cli_expected_type:
195195
},
196196
],
197197
"table_recognition_v2": [
198+
{
199+
"name": "--use_table_cells_ocr_results",
200+
"type": bool,
201+
"help": "Determines whether to use cells OCR results",
202+
},
198203
{
199204
"name": "--use_doc_orientation_classify",
200205
"type": bool,

0 commit comments

Comments
 (0)