Skip to content

Commit f031453

Browse files
URL in post support and test case (#13)
* added url in post logic and test using url
* added common function for extracted text assertion
* renamed assert function

---------

Signed-off-by: Chandrasekharan M <[email protected]>
Co-authored-by: Chandrasekharan M <[email protected]>
1 parent 224b3bf commit f031453

File tree

4 files changed

+378
-301
lines changed

4 files changed

+378
-301
lines changed

src/unstract/llmwhisperer/client_v2.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -152,13 +152,13 @@ def whisper(
152152
file_path: str = "",
153153
stream: IO[bytes] = None,
154154
url: str = "",
155-
mode: str = "high_quality",
155+
mode: str = "form",
156156
output_mode: str = "layout_preserving",
157157
page_seperator: str = "<<<",
158158
pages_to_extract: str = "",
159159
median_filter_size: int = 0,
160160
gaussian_blur_radius: int = 0,
161-
line_splitter_tolerance: float = 0.75,
161+
line_splitter_tolerance: float = 0.4,
162162
horizontal_stretch_factor: float = 1.0,
163163
mark_vertical_lines: bool = False,
164164
mark_horizontal_lines: bool = False,
@@ -216,7 +216,6 @@ def whisper(
216216
self.logger.debug("whisper called")
217217
api_url = f"{self.base_url}/whisper"
218218
params = {
219-
"url": url,
220219
"mode": mode,
221220
"output_mode": output_mode,
222221
"page_seperator": page_seperator,
@@ -281,7 +280,8 @@ def generate():
281280
data=data,
282281
)
283282
else:
284-
req = requests.Request("POST", api_url, params=params, headers=self.headers)
283+
params["url_in_post"] = True
284+
req = requests.Request("POST", api_url, params=params, headers=self.headers, data=url)
285285
prepared = req.prepare()
286286
s = requests.Session()
287287
response = s.send(prepared, timeout=wait_timeout, stream=should_stream)
@@ -350,7 +350,7 @@ def generate():
350350
return message
351351

352352
# Will not reach here if status code is 202
353-
message = response.text
353+
message = json.loads(response.text)
354354
message["status_code"] = response.status_code
355355
return message
356356

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
11
import os
22

33
import pytest
4+
from dotenv import load_dotenv
45

56
from unstract.llmwhisperer.client import LLMWhispererClient
67
from unstract.llmwhisperer.client_v2 import LLMWhispererClientV2
78

9+
load_dotenv()
10+
811

912
@pytest.fixture(name="client")
1013
def llm_whisperer_client():

tests/integration/client_v2_test.py

Lines changed: 66 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -50,23 +50,75 @@ def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
5050

5151
exp_basename = f"{Path(input_file).stem}.{mode}.{output_mode}.txt"
5252
exp_file = os.path.join(data_dir, "expected", exp_basename)
53-
with open(exp_file, encoding="utf-8") as f:
53+
# verify extracted text
54+
assert_extracted_text(exp_file, whisper_result, mode, output_mode)
55+
56+
57+
@pytest.mark.parametrize(
58+
"output_mode, mode, url, input_file, page_count",
59+
[
60+
("layout_preserving", "native_text", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
61+
"credit_card.pdf", 7),
62+
("layout_preserving", "low_cost", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
63+
"credit_card.pdf", 7),
64+
("layout_preserving", "high_quality", "https://unstractpocstorage.blob.core.windows.net/public/scanned_bill.pdf",
65+
"restaurant_invoice_photo.pdf", 1),
66+
("layout_preserving", "form", "https://unstractpocstorage.blob.core.windows.net/public/scanned_form.pdf",
67+
"handwritten-form.pdf", 1),
68+
]
69+
)
70+
def test_whisper_v2_url_in_post(client_v2, data_dir, output_mode, mode, url, input_file, page_count):
71+
usage_before = client_v2.get_usage_info()
72+
whisper_result = client_v2.whisper(
73+
mode=mode, output_mode=output_mode, url=url, wait_for_completion=True
74+
)
75+
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")
76+
77+
exp_basename = f"{Path(input_file).stem}.{mode}.{output_mode}.txt"
78+
exp_file = os.path.join(data_dir, "expected", exp_basename)
79+
# verify extracted text
80+
assert_extracted_text(exp_file, whisper_result, mode, output_mode)
81+
usage_after = client_v2.get_usage_info()
82+
# Verify usage after extraction
83+
verify_usage(usage_before, usage_after, page_count, mode)
84+
85+
86+
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
87+
with open(file_path, encoding="utf-8") as f:
5488
exp = f.read()
5589

5690
assert isinstance(whisper_result, dict)
5791
assert whisper_result["status_code"] == 200
5892

59-
# For text based processing, perform a strict match
93+
# For OCR based processing
94+
threshold = 0.97
95+
96+
# For text based processing
6097
if mode == "native_text" and output_mode == "text":
61-
assert whisper_result["extraction"]["result_text"] == exp
62-
# For OCR based processing, perform a fuzzy match
63-
else:
64-
extracted_text = whisper_result["extraction"]["result_text"]
65-
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
66-
threshold = 0.97
67-
68-
if similarity < threshold:
69-
diff = "\n".join(
70-
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
71-
)
72-
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
98+
threshold = 0.99
99+
extracted_text = whisper_result["extraction"]["result_text"]
100+
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
101+
102+
if similarity < threshold:
103+
diff = "\n".join(
104+
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
105+
)
106+
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
107+
108+
109+
def verify_usage(before_extract, after_extract, page_count, mode='form'):
110+
all_modes = ['form', 'high_quality', 'low_cost', 'native_text']
111+
all_modes.remove(mode)
112+
assert (after_extract['today_page_count'] == before_extract['today_page_count'] + page_count), \
113+
"today_page_count calculation is wrong"
114+
assert (after_extract['current_page_count'] == before_extract['current_page_count'] + page_count), \
115+
"current_page_count calculation is wrong"
116+
if after_extract['overage_page_count'] > 0:
117+
assert (after_extract['overage_page_count'] == before_extract['overage_page_count'] + page_count), \
118+
"overage_page_count calculation is wrong"
119+
assert (after_extract[f'current_page_count_{mode}'] == before_extract[f'current_page_count_{mode}'] + page_count), \
120+
f"{mode} mode calculation is wrong"
121+
for i in range(len(all_modes)):
122+
assert (after_extract[f'current_page_count_{all_modes[i]}'] ==
123+
before_extract[f'current_page_count_{all_modes[i]}']), \
124+
f"{all_modes[i]} mode calculation is wrong"

0 commit comments

Comments
 (0)