
Commit f986f13

feat: add ability to pass languages to OCR agent (#86)
* add language parameter to tesseract
* pass language into elements and layout
* enable loading multiple language agents
* update tests to include ocr languages
* test ocr load
* changelog and version
* version bump for release
1 parent f704f26 commit f986f13
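A minimal usage sketch (not part of the commit) of the new `ocr_languages` parameter at the top-level API. The file name, the `None` model choice, and the "eng+swe" language string are placeholder assumptions; Tesseract joins language codes with "+", and each non-English code requires its language pack to be installed.

from unstructured_inference.inference.layout import process_file_with_model

# Hypothetical example: "example.pdf" and the language string are placeholders.
layout = process_file_with_model(
    "example.pdf",
    model_name=None,          # fall back to the default layout model
    ocr_strategy="auto",      # OCR only where embedded text is unavailable
    ocr_languages="eng+swe",  # assumes the Swedish Tesseract language pack is installed
)

# Iterate the resulting document layout and print the extracted text.
for page in layout.pages:
    for element in page.elements:
        print(element.text)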

File tree

8 files changed: +81 additions, -24 deletions


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,8 @@
+## 0.4.1
+
+* Added the ability to pass `ocr_languages` to the OCR agent for users who need
+  non-English language packs.
+
 ## 0.4.0
 
 * Added logic to partition granular elements (words, characters) by proximity

test_unstructured_inference/inference/test_layout.py

Lines changed: 11 additions & 3 deletions
@@ -44,7 +44,7 @@ class MockOCRAgent:
        def detect(self, *args):
            return mock_text
 
-    monkeypatch.setattr(tesseract, "ocr_agent", MockOCRAgent)
+    monkeypatch.setattr(tesseract, "ocr_agents", {"eng": MockOCRAgent})
     monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True)
 
     image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB")
@@ -96,7 +96,7 @@ def test_get_page_elements_with_ocr(monkeypatch):
     doc_layout = [text_block, image_block]
 
     monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
-    monkeypatch.setattr(elements, "ocr", lambda *args: "An Even Catchier Title")
+    monkeypatch.setattr(elements, "ocr", lambda *args, **kwargs: "An Even Catchier Title")
 
     image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB")
     page = layout.PageLayout(
@@ -187,11 +187,19 @@ def points(self):
 
 
 class MockPageLayout(layout.PageLayout):
-    def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=False):
+    def __init__(
+        self,
+        layout=None,
+        model=None,
+        ocr_strategy="auto",
+        ocr_languages="eng",
+        extract_tables=False,
+    ):
         self.image = None
         self.layout = layout
         self.model = model
         self.ocr_strategy = ocr_strategy
+        self.ocr_languages = ocr_languages
         self.extract_tables = extract_tables
 
     def ocr(self, text_block: MockEmbeddedTextRegion):

test_unstructured_inference/models/test_tesseract.py

Lines changed: 3 additions & 3 deletions
@@ -11,12 +11,12 @@ def __init__(self, languages):
 
 def test_load_agent(monkeypatch):
     monkeypatch.setattr(tesseract, "TesseractAgent", MockTesseractAgent)
-    monkeypatch.setattr(tesseract, "ocr_agent", None)
+    monkeypatch.setattr(tesseract, "ocr_agents", {})
 
     with patch.object(tesseract, "is_pytesseract_available", return_value=True):
-        tesseract.load_agent()
+        tesseract.load_agent(languages="eng+swe")
 
-    assert isinstance(tesseract.ocr_agent, MockTesseractAgent)
+    assert isinstance(tesseract.ocr_agents["eng+swe"], MockTesseractAgent)
 
 
 def test_load_agent_raises_when_not_available():

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.4.0"  # pragma: no cover
+__version__ = "0.4.1"  # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 14 additions & 7 deletions
@@ -162,6 +162,7 @@ def extract_text(
         image: Optional[Image.Image] = None,
         extract_tables: bool = False,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
     ) -> str:
         """Extracts text contained in region."""
         if self.text is not None:
@@ -172,7 +173,7 @@ def extract_text(
         elif image is not None:
             if ocr_strategy != "never":
                 # We don't have anything to go on but the image itself, so we use OCR
-                text = ocr(self, image)
+                text = ocr(self, image, languages=ocr_languages)
             else:
                 text = ""
         else:
@@ -190,6 +191,7 @@ def extract_text(
         image: Optional[Image.Image] = None,
         extract_tables: bool = False,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
     ) -> str:
         """Extracts text contained in region."""
         if self.text is None:
@@ -205,24 +207,28 @@ def extract_text(
         image: Optional[Image.Image] = None,
         extract_tables: bool = False,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
     ) -> str:
         """Extracts text contained in region."""
         if self.text is None:
             if ocr_strategy == "never" or image is None:
                 return ""
             else:
-                return ocr(self, image)
+                return ocr(self, image, languages=ocr_languages)
         else:
             return super().extract_text(objects, image, extract_tables, ocr_strategy)
 
 
-def ocr(text_block: TextRegion, image: Image.Image) -> str:
+def ocr(text_block: TextRegion, image: Image.Image, languages: str = "eng") -> str:
     """Runs a cropped text block image through an OCR agent."""
     logger.debug("Running OCR on text block ...")
-    tesseract.load_agent()
+    tesseract.load_agent(languages=languages)
     padded_block = text_block.pad(12)
     cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2))
-    return tesseract.ocr_agent.detect(cropped_image)
+    agent = tesseract.ocr_agents.get(languages)
+    if agent is None:
+        raise RuntimeError(f"OCR agent is not loaded for {languages}.")
+    return agent.detect(cropped_image)
 
 
 def needs_ocr(
@@ -263,16 +269,17 @@ def aggregate_by_block(
     image: Optional[Image.Image],
     pdf_objects: List[TextRegion],
     ocr_strategy: str = "auto",
+    ocr_languages: str = "eng",
 ) -> str:
     """Extracts the text aggregated from the elements of the given layout that lie within the given
     block."""
     if image is not None and needs_ocr(text_region, pdf_objects, ocr_strategy):
-        text = ocr(text_region, image)
+        text = ocr(text_region, image, languages=ocr_languages)
     else:
         filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)]
         for little_block in filtered_blocks:
             if image is not None and needs_ocr(little_block, pdf_objects, ocr_strategy):
-                little_block.text = ocr(little_block, image)
+                little_block.text = ocr(little_block, image, languages=ocr_languages)
         text = " ".join([x.text for x in filtered_blocks if x.text])
     text = remove_control_characters(text)
     return text

unstructured_inference/inference/layout.py

Lines changed: 29 additions & 3 deletions
@@ -55,6 +55,7 @@ def from_file(
         model: Optional[UnstructuredModel] = None,
         fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
         extract_tables: bool = False,
     ) -> DocumentLayout:
         """Creates a DocumentLayout from a pdf file."""
@@ -75,6 +76,7 @@ def from_file(
                 model=model,
                 layout=layout,
                 ocr_strategy=ocr_strategy,
+                ocr_languages=ocr_languages,
                 fixed_layout=fixed_layout,
                 extract_tables=extract_tables,
             )
@@ -87,6 +89,7 @@ def from_image_file(
         filename: str,
         model: Optional[UnstructuredModel] = None,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
         fixed_layout: Optional[List[TextRegion]] = None,
         extract_tables: bool = False,
     ) -> DocumentLayout:
@@ -104,6 +107,7 @@ def from_image_file(
             model=model,
             layout=None,
             ocr_strategy=ocr_strategy,
+            ocr_languages=ocr_languages,
             fixed_layout=fixed_layout,
             extract_tables=extract_tables,
         )
@@ -120,6 +124,7 @@ def __init__(
         layout: Optional[List[TextRegion]],
         model: Optional[UnstructuredModel] = None,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
         extract_tables: bool = False,
     ):
         self.image = image
@@ -131,6 +136,7 @@ def __init__(
         if ocr_strategy not in VALID_OCR_STRATEGIES:
             raise ValueError(f"ocr_strategy must be one of {VALID_OCR_STRATEGIES}.")
         self.ocr_strategy = ocr_strategy
+        self.ocr_languages = ocr_languages
         self.extract_tables = extract_tables
 
     def __str__(self) -> str:
@@ -159,7 +165,12 @@ def get_elements_from_layout(self, layout: List[TextRegion]) -> List[LayoutEleme
         layout.sort(key=lambda element: element.y1)
         elements = [
             get_element_from_block(
-                e, self.image, self.layout, self.ocr_strategy, self.extract_tables
+                block=e,
+                image=self.image,
+                pdf_objects=self.layout,
+                ocr_strategy=self.ocr_strategy,
+                ocr_languages=self.ocr_languages,
+                extract_tables=self.extract_tables,
             )
             for e in layout
         ]
@@ -178,6 +189,7 @@ def from_image(
         model: Optional[UnstructuredModel] = None,
         layout: Optional[List[TextRegion]] = None,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
         extract_tables: bool = False,
         fixed_layout: Optional[List[TextRegion]] = None,
     ):
@@ -188,6 +200,7 @@ def from_image(
             layout=layout,
             model=model,
             ocr_strategy=ocr_strategy,
+            ocr_languages=ocr_languages,
             extract_tables=extract_tables,
         )
         if fixed_layout is None:
@@ -202,6 +215,7 @@ def process_data_with_model(
     model_name: Optional[str],
     is_image: bool = False,
     ocr_strategy: str = "auto",
+    ocr_languages: str = "eng",
     fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
     extract_tables: bool = False,
 ) -> DocumentLayout:
@@ -214,6 +228,7 @@ def process_data_with_model(
             model_name,
             is_image=is_image,
             ocr_strategy=ocr_strategy,
+            ocr_languages=ocr_languages,
             fixed_layouts=fixed_layouts,
             extract_tables=extract_tables,
         )
@@ -226,6 +241,7 @@ def process_file_with_model(
     model_name: Optional[str],
     is_image: bool = False,
     ocr_strategy: str = "auto",
+    ocr_languages: str = "eng",
     fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
     extract_tables: bool = False,
 ) -> DocumentLayout:
@@ -234,13 +250,18 @@ def process_file_with_model(
     model = get_model(model_name)
     layout = (
         DocumentLayout.from_image_file(
-            filename, model=model, ocr_strategy=ocr_strategy, extract_tables=extract_tables
+            filename,
+            model=model,
+            ocr_strategy=ocr_strategy,
+            ocr_languages=ocr_languages,
+            extract_tables=extract_tables,
         )
         if is_image
         else DocumentLayout.from_file(
            filename,
            model=model,
            ocr_strategy=ocr_strategy,
+            ocr_languages=ocr_languages,
            fixed_layouts=fixed_layouts,
            extract_tables=extract_tables,
        )
@@ -253,13 +274,18 @@ def get_element_from_block(
     image: Optional[Image.Image] = None,
     pdf_objects: Optional[List[TextRegion]] = None,
     ocr_strategy: str = "auto",
+    ocr_languages: str = "eng",
     extract_tables: bool = False,
 ) -> LayoutElement:
     """Creates a LayoutElement from a given layout or image by finding all the text that lies within
     a given block."""
     element = LayoutElement.from_region(block)
     element.text = block.extract_text(
-        objects=pdf_objects, image=image, extract_tables=extract_tables, ocr_strategy=ocr_strategy
+        objects=pdf_objects,
+        image=image,
+        extract_tables=extract_tables,
+        ocr_strategy=ocr_strategy,
+        ocr_languages=ocr_languages,
     )
     return element
 
unstructured_inference/inference/layoutelement.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ def extract_text(
         image: Optional[Image.Image] = None,
         extract_tables: bool = False,
         ocr_strategy: str = "auto",
+        ocr_languages: str = "eng",
     ):
         """Extracts text contained in region"""
         if self.text is not None:
@@ -32,6 +33,7 @@ def extract_text(
                 image=image,
                 extract_tables=extract_tables,
                 ocr_strategy=ocr_strategy,
+                ocr_languages=ocr_languages,
             )
         return text
 

Lines changed: 16 additions & 7 deletions
@@ -1,20 +1,29 @@
+from typing import Dict
+
 from layoutparser.ocr.tesseract_agent import is_pytesseract_available, TesseractAgent
 
 from unstructured_inference.logger import logger
 
-ocr_agent: TesseractAgent = None
+ocr_agents: Dict[str, TesseractAgent] = {}
+
 
+def load_agent(languages: str = "eng"):
+    """Loads the Tesseract OCR agent as a global variable to ensure that we only load it once.
 
-def load_agent():
-    """Loads the Tesseract OCR agent as a global variable to ensure that we only load it once."""
-    global ocr_agent
+    Parameters
+    ----------
+    languages
+        The languages to use for the Tesseract agent. To use a language, you'll first need
+        to install the appropriate Tesseract language pack.
+    """
+    global ocr_agents
 
     if not is_pytesseract_available():
         raise ImportError(
            "Failed to load Tesseract. Ensure that Tesseract is installed. Example command: \n"
            "    >>> sudo apt install -y tesseract-ocr"
        )
 
-    if ocr_agent is None:
-        logger.info("Loading the Tesseract OCR agent ...")
-        ocr_agent = TesseractAgent(languages="eng")
+    if languages not in ocr_agents:
+        logger.info(f"Loading the Tesseract OCR agent for {languages} ...")
+        ocr_agents[languages] = TesseractAgent(languages=languages)
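
A small sketch (not part of the commit) of the caching behavior added above: `load_agent` keeps one `TesseractAgent` per distinct language string in the module-level `ocr_agents` dict. The import path is assumed from the test module path (`test_unstructured_inference/models/test_tesseract.py`), and the example presumes the relevant Tesseract language packs are installed.

from unstructured_inference.models import tesseract

# Each distinct language string gets its own agent, loaded at most once.
tesseract.load_agent(languages="eng")
tesseract.load_agent(languages="eng+swe")  # "+"-joined codes produce one combined agent
tesseract.load_agent(languages="eng")      # already cached, so nothing is reloaded
print(sorted(tesseract.ocr_agents))        # ['eng', 'eng+swe']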
