Skip to content

Commit 0b02756

Browse files
authored
Merge pull request #4 from Zipstack/ocr-adapter-support
Ocr adapter support
2 parents 932d557 + e4a16f8 commit 0b02756

File tree

7 files changed

+354
-446
lines changed

7 files changed

+354
-446
lines changed

pdm.lock

Lines changed: 245 additions & 443 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"python-magic~=0.4.27",
1616
"python-dotenv==1.0.0",
1717
# LLM Triad
18-
"unstract-adapters~=0.2.1",
18+
"unstract-adapters~=0.2.2",
1919
"llama-index==0.9.28",
2020
"tiktoken~=0.4.0",
2121
"transformers==4.37.0",

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.11.1"
1+
__version__ = "0.11.2"
22

33

44
def get_sdk_version():

src/unstract/sdk/ocr.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from abc import ABCMeta
2+
from typing import Optional
3+
4+
from unstract.adapters.constants import Common
5+
from unstract.adapters.ocr import adapters
6+
from unstract.adapters.ocr.ocr_adapter import OCRAdapter
7+
8+
from unstract.sdk.adapters import ToolAdapter
9+
from unstract.sdk.constants import LogLevel
10+
from unstract.sdk.tool.base import BaseTool
11+
12+
13+
class OCR(metaclass=ABCMeta):
14+
def __init__(self, tool: BaseTool):
15+
self.tool = tool
16+
self.ocr_adapters = adapters
17+
18+
def get_ocr(self, adapter_instance_id: str) -> Optional[OCRAdapter]:
19+
try:
20+
ocr_config = ToolAdapter.get_adapter_config(
21+
self.tool, adapter_instance_id
22+
)
23+
ocr_adapter_id = ocr_config.get(Common.ADAPTER_ID)
24+
if ocr_adapter_id in self.ocr_adapters:
25+
ocr_adapter = self.ocr_adapters[ocr_adapter_id][
26+
Common.METADATA
27+
][Common.ADAPTER]
28+
ocr_metadata = ocr_config.get(Common.ADAPTER_METADATA)
29+
ocr_adapter_class = ocr_adapter(ocr_metadata)
30+
31+
return ocr_adapter_class
32+
33+
except Exception as e:
34+
self.tool.stream_log(
35+
log=f"Unable to get OCR adapter {adapter_instance_id}: {e}",
36+
level=LogLevel.ERROR,
37+
)
38+
return None

src/unstract/sdk/x2txt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABCMeta
2+
from typing import Optional
23

34
from unstract.adapters.constants import Common
45
from unstract.adapters.x2text import adapters
@@ -15,7 +16,7 @@ def __init__(self, tool: BaseTool):
1516
self.tool = tool
1617
self.x2text_adapters = adapters
1718

18-
def get_x2text(self, adapter_instance_id: str) -> X2TextAdapter:
19+
def get_x2text(self, adapter_instance_id: str) -> Optional[X2TextAdapter]:
1920
try:
2021
x2text_config = ToolAdapter.get_adapter_config(
2122
self.tool, adapter_instance_id

tests/sample.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@ X2TEXT_PORT=3004
88
LLM_TEST_VALUES=["", "", ""]
99
EMBEDDING_TEST_VALUES=["", "", ""]
1010
VECTOR_DB_TEST_VALUES=["", "", ""]
11+
OCR_TEST_VALUES=["", ""]
1112
X2TEXT_TEST_VALUES=["", ""]
13+

tests/test_ocr.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
import logging
3+
import os
4+
import unittest
5+
from typing import Any
6+
7+
from dotenv import load_dotenv
8+
from parameterized import parameterized
9+
10+
from unstract.sdk.ocr import OCR
11+
from unstract.sdk.tool.base import BaseTool
12+
13+
load_dotenv()
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def get_test_values(env_key: str) -> list[str]:
19+
values = json.loads(os.environ.get(env_key))
20+
return values
21+
22+
23+
def get_env_value(env_key: str) -> str:
24+
value = os.environ.get(env_key)
25+
return value
26+
27+
28+
class ToolOCRTest(unittest.TestCase):
29+
class MockTool(BaseTool):
30+
def run(
31+
self,
32+
params: dict[str, Any] = {},
33+
settings: dict[str, Any] = {},
34+
workflow_id: str = "",
35+
) -> None:
36+
pass
37+
38+
@classmethod
39+
def setUpClass(cls):
40+
cls.tool = cls.MockTool()
41+
42+
@parameterized.expand(get_test_values("OCR_TEST_VALUES"))
43+
def test_get_ocr(self, adapter_instance_id):
44+
tool_ocr = OCR(tool=self.tool)
45+
ocr = tool_ocr.get_ocr(adapter_instance_id)
46+
result = ocr.test_connection()
47+
self.assertTrue(result)
48+
input_file = get_env_value("INPUT_FILE_PATH")
49+
output_file = get_env_value("OUTPUT_FILE_PATH")
50+
if os.path.isfile(output_file):
51+
os.remove(output_file)
52+
output = ocr.process(input_file, output_file)
53+
file_size = os.path.getsize(output_file)
54+
self.assertGreater(file_size, 0)
55+
if os.path.isfile(output_file):
56+
os.remove(output_file)
57+
with open(output_file, "w", encoding="utf-8") as f:
58+
f.write(output)
59+
f.close()
60+
file_size = os.path.getsize(output_file)
61+
self.assertGreater(file_size, 0)
62+
63+
64+
if __name__ == "__main__":
65+
unittest.main()

0 commit comments

Comments
 (0)