Skip to content

Commit cb55245

Browse files
authored
rfctr: extract OCRAgent.get_agent() out of PDF subtree (#2965)
**Summary** File-types other than PDF need to use OCR on extracted images. Extract `OCRAgent.get_agent()` such that any file-type partitioner can use it without risking dependency on PDF-only extras.
1 parent 17c2d07 commit cb55245

File tree

7 files changed

+178
-82
lines changed

7 files changed

+178
-82
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
## 0.13.7-dev5
1+
## 0.13.7-dev6
22

33
### Enhancements
44

55
* **Remove `page_number` metadata fields** for HTML partition until we have a better strategy to decide page counting.
6+
* **Extract OCRAgent.get_agent().** Generalize access to the configured OCRAgent instance beyond its use for PDFs.
67

78
### Features
89

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# pyright: reportPrivateUsage=false
2+
3+
"""Unit-test suite for the `unstructured.partition.utils.ocr_models.ocr_interface` module."""
4+
5+
from __future__ import annotations
6+
7+
import pytest
8+
9+
from test_unstructured.unit_utils import (
10+
FixtureRequest,
11+
LogCaptureFixture,
12+
Mock,
13+
instance_mock,
14+
method_mock,
15+
property_mock,
16+
)
17+
from unstructured.partition.utils.config import ENVConfig
18+
from unstructured.partition.utils.constants import (
19+
OCR_AGENT_PADDLE,
20+
OCR_AGENT_PADDLE_OLD,
21+
OCR_AGENT_TESSERACT,
22+
OCR_AGENT_TESSERACT_OLD,
23+
)
24+
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
25+
26+
27+
class DescribeOCRAgent:
28+
"""Unit-test suite for `unstructured.partition.utils...ocr_interface.OCRAgent` class."""
29+
30+
def it_provides_access_to_the_configured_OCR_agent(
31+
self, _get_ocr_agent_cls_qname_: Mock, get_instance_: Mock, ocr_agent_: Mock
32+
):
33+
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
34+
get_instance_.return_value = ocr_agent_
35+
36+
ocr_agent = OCRAgent.get_agent()
37+
38+
_get_ocr_agent_cls_qname_.assert_called_once_with()
39+
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
40+
assert ocr_agent is ocr_agent_
41+
42+
@pytest.mark.parametrize("ExceptionCls", [ImportError, AttributeError])
43+
def but_it_raises_whan_no_such_ocr_agent_class_is_found(
44+
self, ExceptionCls: type, _get_ocr_agent_cls_qname_: Mock, get_instance_: Mock
45+
):
46+
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
47+
get_instance_.side_effect = ExceptionCls
48+
49+
with pytest.raises(ValueError, match="OCR_AGENT must be set to an existing OCR agent "):
50+
OCRAgent.get_agent()
51+
52+
_get_ocr_agent_cls_qname_.assert_called_once_with()
53+
get_instance_.assert_called_once_with("Invalid.Ocr.Agent.Qname")
54+
55+
@pytest.mark.parametrize(
56+
("OCR_AGENT", "expected_value"),
57+
[
58+
(OCR_AGENT_PADDLE, OCR_AGENT_PADDLE),
59+
(OCR_AGENT_PADDLE_OLD, OCR_AGENT_PADDLE),
60+
(OCR_AGENT_TESSERACT, OCR_AGENT_TESSERACT),
61+
(OCR_AGENT_TESSERACT_OLD, OCR_AGENT_TESSERACT),
62+
],
63+
)
64+
def it_computes_the_OCR_agent_qualified_module_name(
65+
self, OCR_AGENT: str, expected_value: str, OCR_AGENT_prop_: Mock
66+
):
67+
OCR_AGENT_prop_.return_value = OCR_AGENT
68+
assert OCRAgent._get_ocr_agent_cls_qname() == expected_value
69+
70+
@pytest.mark.parametrize("OCR_AGENT", [OCR_AGENT_PADDLE_OLD, OCR_AGENT_TESSERACT_OLD])
71+
def and_it_logs_a_warning_when_the_OCR_AGENT_module_name_is_obsolete(
72+
self, caplog: LogCaptureFixture, OCR_AGENT: str, OCR_AGENT_prop_: Mock
73+
):
74+
OCR_AGENT_prop_.return_value = OCR_AGENT
75+
OCRAgent._get_ocr_agent_cls_qname()
76+
assert f"OCR agent name {OCR_AGENT} is outdated " in caplog.text
77+
78+
# -- fixtures --------------------------------------------------------------------------------
79+
80+
@pytest.fixture()
81+
def get_instance_(self, request: FixtureRequest):
82+
return method_mock(request, OCRAgent, "get_instance")
83+
84+
@pytest.fixture()
85+
def _get_ocr_agent_cls_qname_(self, request: FixtureRequest):
86+
return method_mock(request, OCRAgent, "_get_ocr_agent_cls_qname")
87+
88+
@pytest.fixture()
89+
def ocr_agent_(self, request: FixtureRequest):
90+
return instance_mock(request, OCRAgent)
91+
92+
@pytest.fixture()
93+
def OCR_AGENT_prop_(self, request: FixtureRequest):
94+
return property_mock(request, ENVConfig, "OCR_AGENT")

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.13.7-dev5" # pragma: no cover
1+
__version__ = "0.13.7-dev6" # pragma: no cover

unstructured/partition/pdf.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -937,21 +937,19 @@ def _partition_pdf_or_image_with_ocr(
937937

938938

939939
def _partition_pdf_or_image_with_ocr_from_image(
940-
image: PILImage,
940+
image: PILImage.Image,
941941
languages: Optional[list[str]] = None,
942942
page_number: int = 1,
943943
include_page_breaks: bool = False,
944944
metadata_last_modified: Optional[str] = None,
945945
sort_mode: str = SORT_MODE_XY_CUT,
946-
**kwargs,
946+
**kwargs: Any,
947947
) -> list[Element]:
948948
"""Extract `unstructured` elements from an image using OCR and perform partitioning."""
949949

950-
from unstructured.partition.pdf_image.ocr import (
951-
get_ocr_agent,
952-
)
950+
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
953951

954-
ocr_agent = get_ocr_agent()
952+
ocr_agent = OCRAgent.get_agent()
955953
ocr_languages = prepare_languages_for_tesseract(languages)
956954

957955
# NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text

unstructured/partition/pdf_image/ocr.py

Lines changed: 3 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,11 @@
1212
from PIL import ImageSequence
1313

1414
from unstructured.documents.elements import ElementType
15-
from unstructured.logger import logger
1615
from unstructured.metrics.table.table_formats import SimpleTableCell
1716
from unstructured.partition.pdf_image.pdf_image_utils import pad_element_bboxes, valid_text
1817
from unstructured.partition.utils.config import env_config
19-
from unstructured.partition.utils.constants import (
20-
OCR_AGENT_PADDLE,
21-
OCR_AGENT_PADDLE_OLD,
22-
OCR_AGENT_TESSERACT,
23-
OCR_AGENT_TESSERACT_OLD,
24-
OCRMode,
25-
)
26-
from unstructured.partition.utils.ocr_models.ocr_interface import (
27-
OCRAgent,
28-
)
18+
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCRMode
19+
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
2920
from unstructured.utils import requires_dependencies
3021

3122
if TYPE_CHECKING:
@@ -35,12 +26,6 @@
3526
from unstructured_inference.models.tables import UnstructuredTableTransformerModel
3627

3728

38-
# Force tesseract to be single threaded,
39-
# otherwise we see major performance problems
40-
if "OMP_THREAD_LIMIT" not in os.environ:
41-
os.environ["OMP_THREAD_LIMIT"] = "1"
42-
43-
4429
def process_data_with_ocr(
4530
data: bytes | IO[bytes],
4631
out_layout: "DocumentLayout",
@@ -200,7 +185,7 @@ def supplement_page_layout_with_ocr(
200185
with no text and add text from OCR to each element.
201186
"""
202187

203-
ocr_agent = get_ocr_agent()
188+
ocr_agent = OCRAgent.get_agent()
204189
if ocr_mode == OCRMode.FULL_PAGE.value:
205190
ocr_layout = ocr_agent.get_layout_from_image(
206191
image,
@@ -453,34 +438,3 @@ def supplement_layout_with_ocr_elements(
453438
final_layout = layout
454439

455440
return final_layout
456-
457-
458-
def get_ocr_agent() -> OCRAgent:
459-
ocr_agent_module = env_config.OCR_AGENT
460-
message = (
461-
"OCR agent name %s is outdated and will be deprecated in a future release; please use %s "
462-
"instead"
463-
)
464-
# deal with compatibility with origin way to set OCR
465-
if ocr_agent_module.lower() == OCR_AGENT_TESSERACT_OLD:
466-
logger.warning(
467-
message,
468-
ocr_agent_module,
469-
OCR_AGENT_TESSERACT,
470-
)
471-
ocr_agent_module = OCR_AGENT_TESSERACT
472-
elif ocr_agent_module.lower() == OCR_AGENT_PADDLE_OLD:
473-
logger.warning(
474-
message,
475-
ocr_agent_module,
476-
OCR_AGENT_PADDLE,
477-
)
478-
ocr_agent_module = OCR_AGENT_PADDLE
479-
try:
480-
ocr_agent = OCRAgent.get_instance(ocr_agent_module)
481-
except (ImportError, AttributeError):
482-
raise ValueError(
483-
"Environment variable OCR_AGENT",
484-
f" must be set to an existing ocr agent module, not {ocr_agent_module}.",
485-
)
486-
return ocr_agent

unstructured/partition/utils/ocr_models/ocr_interface.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,39 @@
55
from abc import ABC, abstractmethod
66
from typing import TYPE_CHECKING
77

8-
from unstructured.partition.utils.constants import OCR_AGENT_MODULES_WHITELIST
8+
from unstructured.logger import logger
9+
from unstructured.partition.utils.config import env_config
10+
from unstructured.partition.utils.constants import (
11+
OCR_AGENT_MODULES_WHITELIST,
12+
OCR_AGENT_PADDLE,
13+
OCR_AGENT_PADDLE_OLD,
14+
OCR_AGENT_TESSERACT,
15+
OCR_AGENT_TESSERACT_OLD,
16+
)
917

1018
if TYPE_CHECKING:
1119
from PIL import Image as PILImage
1220
from unstructured_inference.inference.elements import TextRegion
13-
from unstructured_inference.inference.layoutelement import (
14-
LayoutElement,
15-
)
21+
from unstructured_inference.inference.layoutelement import LayoutElement
1622

1723

1824
class OCRAgent(ABC):
1925
"""Defines the interface for an Optical Character Recognition (OCR) service."""
2026

21-
@abstractmethod
22-
def is_text_sorted(self) -> bool:
23-
pass
24-
25-
@abstractmethod
26-
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
27-
pass
28-
29-
@abstractmethod
30-
def get_layout_from_image(
31-
self, image: PILImage.Image, ocr_languages: str = "eng"
32-
) -> list[TextRegion]:
33-
pass
27+
@classmethod
28+
def get_agent(cls) -> OCRAgent:
29+
"""Get the configured OCRAgent instance.
3430
35-
@abstractmethod
36-
def get_layout_elements_from_image(
37-
self, image: PILImage.Image, ocr_languages: str = "eng"
38-
) -> list[LayoutElement]:
39-
pass
31+
The OCR package used by the agent is determined by the `OCR_AGENT` environment variable.
32+
"""
33+
ocr_agent_cls_qname = cls._get_ocr_agent_cls_qname()
34+
try:
35+
return cls.get_instance(ocr_agent_cls_qname)
36+
except (ImportError, AttributeError):
37+
raise ValueError(
38+
f"Environment variable OCR_AGENT must be set to an existing OCR agent module,"
39+
f" not {ocr_agent_cls_qname}."
40+
)
4041

4142
@staticmethod
4243
@functools.lru_cache(maxsize=None)
@@ -51,3 +52,48 @@ def get_instance(ocr_agent_module: str) -> "OCRAgent":
5152
f"Environment variable OCR_AGENT module name {module_name}, must be set to a"
5253
f" whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}.",
5354
)
55+
56+
@abstractmethod
57+
def get_layout_elements_from_image(
58+
self, image: PILImage.Image, ocr_languages: str = "eng"
59+
) -> list[LayoutElement]:
60+
pass
61+
62+
@abstractmethod
63+
def get_layout_from_image(
64+
self, image: PILImage.Image, ocr_languages: str = "eng"
65+
) -> list[TextRegion]:
66+
pass
67+
68+
@abstractmethod
69+
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
70+
pass
71+
72+
@abstractmethod
73+
def is_text_sorted(self) -> bool:
74+
pass
75+
76+
@staticmethod
77+
def _get_ocr_agent_cls_qname() -> str:
78+
"""Get the fully-qualified class name of the configured OCR agent.
79+
80+
The qualified name (qname) looks like:
81+
"unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
82+
83+
The qname provides the full module address and class name of the OCR agent.
84+
"""
85+
ocr_agent_qname = env_config.OCR_AGENT
86+
87+
# -- map legacy method of setting OCR agent by key-name to full qname --
88+
qnames_by_keyname = {
89+
OCR_AGENT_TESSERACT_OLD: OCR_AGENT_TESSERACT,
90+
OCR_AGENT_PADDLE_OLD: OCR_AGENT_PADDLE,
91+
}
92+
if qname_mapped_from_keyname := qnames_by_keyname.get(ocr_agent_qname.lower()):
93+
logger.warning(
94+
f"OCR agent name {ocr_agent_qname} is outdated and will be removed in a future"
95+
f" release; please use {qname_mapped_from_keyname} instead"
96+
)
97+
return qname_mapped_from_keyname
98+
99+
return ocr_agent_qname

unstructured/partition/utils/ocr_models/tesseract_ocr.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from typing import TYPE_CHECKING, List
45

56
import cv2
@@ -22,9 +23,11 @@
2223

2324
if TYPE_CHECKING:
2425
from unstructured_inference.inference.elements import TextRegion
25-
from unstructured_inference.inference.layoutelement import (
26-
LayoutElement,
27-
)
26+
from unstructured_inference.inference.layoutelement import LayoutElement
27+
28+
# -- force tesseract to be single threaded, otherwise we see major performance problems --
29+
if "OMP_THREAD_LIMIT" not in os.environ:
30+
os.environ["OMP_THREAD_LIMIT"] = "1"
2831

2932

3033
class OCRAgentTesseract(OCRAgent):

0 commit comments

Comments
 (0)