11import logging
22import os
3- import unittest
3+ from difflib import SequenceMatcher , unified_diff
44from pathlib import Path
55
66import pytest
7- import requests
8-
9- from unstract .llmwhisperer import LLMWhispererClient
107
118logger = logging .getLogger (__name__ )
129
@@ -23,9 +20,7 @@ def test_get_usage_info(client):
2320 "subscription_plan" ,
2421 "today_page_count" ,
2522 ]
26- assert set (usage_info .keys ()) == set (
27- expected_keys
28- ), f"usage_info { usage_info } does not contain the expected keys"
23+ assert set (usage_info .keys ()) == set (expected_keys ), f"usage_info { usage_info } does not contain the expected keys"
2924
3025
3126@pytest .mark .parametrize (
@@ -41,85 +36,38 @@ def test_get_usage_info(client):
4136)
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
    """Run a whisper request on a sample file and compare it with the stored expectation.

    The expected text is looked up under ``<data_dir>/expected`` using the file name
    ``<input stem>.<processing_mode>.<output_mode>.txt``; the comparison itself
    (status check plus similarity threshold) is delegated to ``assert_extracted_text``.
    """
    file_path = os.path.join(data_dir, input_file)
    whisper_result = client.whisper(
        processing_mode=processing_mode,
        output_mode=output_mode,
        file_path=file_path,
        timeout=200,
    )
    logger.debug(whisper_result)

    # One expectation file per (input, processing mode, output mode) combination.
    exp_basename = f"{Path(input_file).stem}.{processing_mode}.{output_mode}.txt"
    exp_file = os.path.join(data_dir, "expected", exp_basename)

    assert_extracted_text(exp_file, whisper_result, processing_mode, output_mode)
6051
6152
62- # TODO: Review and port to pytest based tests
63- class TestLLMWhispererClient (unittest .TestCase ):
64- @unittest .skip ("Skipping test_whisper" )
65- def test_whisper (self ):
66- client = LLMWhispererClient ()
67- # response = client.whisper(
68- # url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
69- # )
70- response = client .whisper (
71- file_path = "test_data/restaurant_invoice_photo.pdf" ,
72- timeout = 200 ,
73- store_metadata_for_highlighting = True ,
74- )
75- print (response )
76- # self.assertIsInstance(response, dict)
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
    """Assert that a whisper response is successful and close enough to the expected text.

    Args:
        file_path: Path of the UTF-8 text file holding the expected extraction.
        whisper_result: Response to check; must be a dict with ``status_code``
            equal to 200 and an ``extracted_text`` entry.
        mode: Processing mode used for the request.
        output_mode: Output mode used for the request.

    The extracted text does not have to match exactly: its ``SequenceMatcher``
    ratio against the expected text must reach 0.99 when ``mode`` is
    ``"native_text"`` and ``output_mode`` is ``"text"``, and 0.97 for every other
    combination. Below the threshold the test is failed via ``pytest.fail`` with
    a unified diff of expected versus extracted text.
    """
    with open(file_path, encoding="utf-8") as f:
        exp = f.read()

    assert isinstance(whisper_result, dict), f"whisper_result is not a dict: {type(whisper_result)!r}"
    assert whisper_result["status_code"] == 200, f"unexpected status_code: {whisper_result['status_code']!r}"

    # For OCR based processing
    threshold = 0.97

    # For text based processing
    if mode == "native_text" and output_mode == "text":
        threshold = 0.99

    extracted_text = whisper_result["extracted_text"]
    similarity = SequenceMatcher(None, extracted_text, exp).ratio()

    if similarity < threshold:
        # splitlines() drops the line endings, so lineterm must be "" as well;
        # otherwise the ---/+++/@@ control lines keep a trailing newline and the
        # joined diff is littered with blank lines.
        diff = "\n".join(
            unified_diff(
                exp.splitlines(),
                extracted_text.splitlines(),
                fromfile="Expected",
                tofile="Extracted",
                lineterm="",
            )
        )
        pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
0 commit comments