|
1 | 1 | # https://github.com/microsoft/table-transformer/blob/main/src/inference.py |
2 | 2 | # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb |
3 | 3 | import logging |
4 | | -import platform |
| 4 | +import os |
5 | 5 | import xml.etree.ElementTree as ET |
6 | 6 | from collections import defaultdict |
7 | 7 | from pathlib import Path |
@@ -56,49 +56,48 @@ def initialize( |
56 | 56 |
|
def get_tokens(self, x: Image):
    """Get OCR tokens from either paddleocr or tesseract.

    The OCR engine is chosen by the ``TABLE_OCR`` environment variable
    (compared case-insensitively). When the variable is unset, tesseract is
    used.

    Args:
        x: The image to run OCR on.

    Returns:
        A list of dicts, one per recognized token, each holding a ``"bbox"``
        entry (``[xmin, ymin, xmax, ymax]``) and a ``"text"`` entry.

    Raises:
        ValueError: If ``TABLE_OCR`` is set to anything other than
            ``"tesseract"`` or ``"paddle"``.
    """
    table_ocr = os.getenv("TABLE_OCR", "tesseract").lower()
    if table_ocr not in ("paddle", "tesseract"):
        raise ValueError(
            "Environment variable TABLE_OCR must be set to 'tesseract' or 'paddle'.",
        )

    if table_ocr == "paddle":
        # Function-scope import: the paddle module is only loaded when this
        # engine has been explicitly selected.
        from unstructured_inference.models import paddle_ocr

        paddle_result = paddle_ocr.load_agent().ocr(np.array(x), cls=True)

        tokens = []
        # Some PaddleOCR versions return None for an image with no detected
        # text; skip such entries instead of failing when iterating them.
        for res in paddle_result or []:
            if not res:
                continue
            for line in res:
                # line[0] is the sequence of (x, y) points outlining the
                # detection; line[1][0] is the recognized text.
                xs = [point[0] for point in line[0]]
                ys = [point[1] for point in line[0]]
                tokens.append(
                    {"bbox": [min(xs), min(ys), max(xs), max(ys)], "text": line[1][0]},
                )
        return tokens

    ocr_df: pd.DataFrame = pytesseract.image_to_data(
        x,
        output_type="data.frame",
    )

    # Discard rows with any missing value before building tokens.
    ocr_df = ocr_df.dropna()

    return [
        {
            "bbox": [
                row.left,
                row.top,
                row.left + row.width,
                row.top + row.height,
            ],
            "text": row.text,
        }
        for row in ocr_df.itertuples()
    ]
102 | 101 |
|
103 | 102 | def run_prediction(self, x: Image): |
104 | 103 | """Predict table structure""" |
|
0 commit comments