Skip to content

Commit 124c562

Browse files
feat: v2 client tests (#12)
* V2 client test for whisper API * Fixed test outputs, refactored tests into integration folder * Added minor unit tests for callback related functions * Fixed typo in filename * Minor pre-commit config fix
1 parent a6af309 commit 124c562

16 files changed

+2840
-167
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ repos:
1717
exclude_types:
1818
- "markdown"
1919
- id: end-of-file-fixer
20+
exclude: "tests/test_data/.*"
2021
- id: check-yaml
2122
args: [--unsafe]
2223
- id: check-added-large-files
@@ -65,9 +66,7 @@ repos:
6566
args: [--max-line-length=120]
6667
exclude: |
6768
(?x)^(
68-
.*migrations/.*\.py|
69-
unstract-core/tests/.*|
70-
pkgs/unstract-flags/src/unstract/flags/evaluation_.*\.py|
69+
tests/test_data/.*|
7170
)$
7271
- repo: https://github.com/pycqa/isort
7372
rev: 5.13.2

src/unstract/llmwhisperer/client_v2.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@
1818
LLMWhispererClientException: Exception raised for errors in the LLMWhispererClient.
1919
"""
2020

21+
import copy
2122
import json
2223
import logging
2324
import os
24-
from typing import IO
25-
import copy
2625
import time
26+
from typing import IO
2727

2828
import requests
2929

30-
3130
BASE_URL = "https://llmwhisperer-api.unstract.com/api/v2"
3231

3332

@@ -63,9 +62,7 @@ class LLMWhispererClientV2:
6362
client's activities and errors.
6463
"""
6564

66-
formatter = logging.Formatter(
67-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
68-
)
65+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
6966
logger = logging.getLogger(__name__)
7067
log_stream_handler = logging.StreamHandler()
7168
log_stream_handler.setFormatter(formatter)
@@ -251,12 +248,10 @@ def whisper(
251248
should_stream = False
252249
if url == "":
253250
if stream is not None:
254-
255251
should_stream = True
256252

257253
def generate():
258-
for chunk in stream:
259-
yield chunk
254+
yield from stream
260255

261256
req = requests.Request(
262257
"POST",
@@ -302,41 +297,31 @@ def generate():
302297
message["extraction"] = {}
303298
return message
304299
if status["status"] == "processing":
305-
self.logger.debug(
306-
f"Whisper-hash:{whisper_hash} | STATUS: processing..."
307-
)
300+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processing...")
308301
elif status["status"] == "delivered":
309-
self.logger.debug(
310-
f"Whisper-hash:{whisper_hash} | STATUS: Already delivered!"
311-
)
302+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: Already delivered!")
312303
raise LLMWhispererClientException(
313304
{
314305
"status_code": -1,
315306
"message": "Whisper operation already delivered",
316307
}
317308
)
318309
elif status["status"] == "unknown":
319-
self.logger.debug(
320-
f"Whisper-hash:{whisper_hash} | STATUS: unknown..."
321-
)
310+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: unknown...")
322311
raise LLMWhispererClientException(
323312
{
324313
"status_code": -1,
325314
"message": "Whisper operation status unknown",
326315
}
327316
)
328317
elif status["status"] == "failed":
329-
self.logger.debug(
330-
f"Whisper-hash:{whisper_hash} | STATUS: failed..."
331-
)
318+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: failed...")
332319
message["status_code"] = -1
333320
message["message"] = "Whisper operation failed"
334321
message["extraction"] = {}
335322
return message
336323
elif status["status"] == "processed":
337-
self.logger.debug(
338-
f"Whisper-hash:{whisper_hash} | STATUS: processed!"
339-
)
324+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processed!")
340325
resultx = self.whisper_retrieve(whisper_hash=whisper_hash)
341326
if resultx["status_code"] == 200:
342327
message["status_code"] = 200
@@ -451,7 +436,6 @@ def register_webhook(self, url: str, auth_token: str, webhook_name: str) -> dict
451436
Raises:
452437
LLMWhispererClientException: If the API request fails, it raises an exception with
453438
the error message and status code returned by the API.
454-
455439
"""
456440

457441
data = {
@@ -489,7 +473,6 @@ def get_webhook_details(self, webhook_name: str) -> dict:
489473
Raises:
490474
LLMWhispererClientException: If the API request fails, it raises an exception with
491475
the error message and status code returned by the API.
492-
493476
"""
494477

495478
url = f"{self.base_url}/whisper-manage-callback"

tests/client_test_v2.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

tests/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,21 @@
33
import pytest
44

55
from unstract.llmwhisperer.client import LLMWhispererClient
6+
from unstract.llmwhisperer.client_v2 import LLMWhispererClientV2
67

78

89
@pytest.fixture(name="client")
910
def llm_whisperer_client():
10-
# Create an instance of the client
1111
client = LLMWhispererClient()
1212
return client
1313

1414

15+
@pytest.fixture(name="client_v2")
16+
def llm_whisperer_client_v2():
17+
client = LLMWhispererClientV2()
18+
return client
19+
20+
1521
@pytest.fixture(name="data_dir", scope="session")
1622
def test_data_dir():
1723
return os.path.join(os.path.dirname(__file__), "test_data")

tests/integration/__init__.py

Whitespace-only changes.

tests/client_test.py renamed to tests/integration/client_test.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ def test_get_usage_info(client):
2323
"subscription_plan",
2424
"today_page_count",
2525
]
26-
assert set(usage_info.keys()) == set(
27-
expected_keys
28-
), f"usage_info {usage_info} does not contain the expected keys"
26+
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"
2927

3028

3129
@pytest.mark.parametrize(
@@ -78,9 +76,7 @@ def test_whisper(self):
7876
# @unittest.skip("Skipping test_whisper")
7977
def test_whisper_stream(self):
8078
client = LLMWhispererClient()
81-
download_url = (
82-
"https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83-
)
79+
download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
8480
# Create a stream of download_url and pass it to whisper
8581
response_download = requests.get(download_url, stream=True)
8682
response_download.raise_for_status()
@@ -95,18 +91,14 @@ def test_whisper_stream(self):
9591
@unittest.skip("Skipping test_whisper_status")
9692
def test_whisper_status(self):
9793
client = LLMWhispererClient()
98-
response = client.whisper_status(
99-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100-
)
94+
response = client.whisper_status(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
10195
logger.info(response)
10296
self.assertIsInstance(response, dict)
10397

10498
@unittest.skip("Skipping test_whisper_retrieve")
10599
def test_whisper_retrieve(self):
106100
client = LLMWhispererClient()
107-
response = client.whisper_retrieve(
108-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109-
)
101+
response = client.whisper_retrieve(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
110102
logger.info(response)
111103
self.assertIsInstance(response, dict)
112104

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import logging
2+
import os
3+
from difflib import SequenceMatcher, unified_diff
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def test_get_usage_info(client_v2):
12+
usage_info = client_v2.get_usage_info()
13+
logger.info(usage_info)
14+
assert isinstance(usage_info, dict), "usage_info should be a dictionary"
15+
expected_keys = [
16+
"current_page_count",
17+
"current_page_count_low_cost",
18+
"current_page_count_form",
19+
"current_page_count_high_quality",
20+
"current_page_count_native_text",
21+
"daily_quota",
22+
"monthly_quota",
23+
"overage_page_count",
24+
"subscription_plan",
25+
"today_page_count",
26+
]
27+
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"
28+
29+
30+
@pytest.mark.parametrize(
31+
"output_mode, mode, input_file",
32+
[
33+
("layout_preserving", "native_text", "credit_card.pdf"),
34+
("layout_preserving", "low_cost", "credit_card.pdf"),
35+
("layout_preserving", "high_quality", "restaurant_invoice_photo.pdf"),
36+
("layout_preserving", "form", "handwritten-form.pdf"),
37+
("text", "native_text", "credit_card.pdf"),
38+
("text", "low_cost", "credit_card.pdf"),
39+
("text", "high_quality", "restaurant_invoice_photo.pdf"),
40+
("text", "form", "handwritten-form.pdf"),
41+
],
42+
)
43+
def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
44+
file_path = os.path.join(data_dir, input_file)
45+
whisper_result = client_v2.whisper(
46+
mode=mode, output_mode=output_mode, file_path=file_path, wait_for_completion=True
47+
)
48+
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")
49+
50+
exp_basename = f"{Path(input_file).stem}.{mode}.{output_mode}.txt"
51+
exp_file = os.path.join(data_dir, "expected", exp_basename)
52+
with open(exp_file, encoding="utf-8") as f:
53+
exp = f.read()
54+
55+
assert isinstance(whisper_result, dict)
56+
assert whisper_result["status_code"] == 200
57+
58+
# For text based processing, perform a strict match
59+
if mode == "native_text" and output_mode == "text":
60+
assert whisper_result["extraction"]["result_text"] == exp
61+
# For OCR based processing, perform a fuzzy match
62+
else:
63+
extracted_text = whisper_result["extraction"]["result_text"]
64+
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
65+
threshold = 0.97
66+
67+
if similarity < threshold:
68+
diff = "\n".join(
69+
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
70+
)
71+
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")

0 commit comments

Comments
 (0)