11import logging
22import os
33import unittest
4+ from difflib import SequenceMatcher , unified_diff
45from pathlib import Path
56
67import pytest
@@ -23,9 +24,7 @@ def test_get_usage_info(client):
2324 "subscription_plan" ,
2425 "today_page_count" ,
2526 ]
26- assert set (usage_info .keys ()) == set (
27- expected_keys
28- ), f"usage_info { usage_info } does not contain the expected keys"
27+ assert set (usage_info .keys ()) == set (expected_keys ), f"usage_info { usage_info } does not contain the expected keys"
2928
3029
3130@pytest .mark .parametrize (
@@ -56,7 +55,21 @@ def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
5655
5756 assert isinstance (response , dict )
5857 assert response ["status_code" ] == 200
59- assert response ["extracted_text" ] == exp
58+
59+ # For text based processing, perform a strict match
60+ if processing_mode == "text" and output_mode == "text" :
61+ assert response ["extracted_text" ] == exp
62+ # For OCR based processing, perform a fuzzy match
63+ else :
64+ extracted_text = response ["extracted_text" ]
65+ similarity = SequenceMatcher (None , extracted_text , exp ).ratio ()
66+ threshold = 0.97
67+
68+ if similarity < threshold :
69+ diff = "\n " .join (
70+ unified_diff (exp .splitlines (), extracted_text .splitlines (), fromfile = "Expected" , tofile = "Extracted" )
71+ )
72+ pytest .fail (f"Texts are not similar enough: { similarity * 100 :.2f} % similarity. Diff:\n { diff } " )
6073
6174
6275# TODO: Review and port to pytest based tests
@@ -78,9 +91,7 @@ def test_whisper(self):
7891 # @unittest.skip("Skipping test_whisper")
7992 def test_whisper_stream (self ):
8093 client = LLMWhispererClient ()
81- download_url = (
82- "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83- )
94+ download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
8495 # Create a stream of download_url and pass it to whisper
8596 response_download = requests .get (download_url , stream = True )
8697 response_download .raise_for_status ()
@@ -95,18 +106,14 @@ def test_whisper_stream(self):
95106 @unittest .skip ("Skipping test_whisper_status" )
96107 def test_whisper_status (self ):
97108 client = LLMWhispererClient ()
98- response = client .whisper_status (
99- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100- )
109+ response = client .whisper_status (whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a" )
101110 logger .info (response )
102111 self .assertIsInstance (response , dict )
103112
104113 @unittest .skip ("Skipping test_whisper_retrieve" )
105114 def test_whisper_retrieve (self ):
106115 client = LLMWhispererClient ()
107- response = client .whisper_retrieve (
108- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109- )
116+ response = client .whisper_retrieve (whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a" )
110117 logger .info (response )
111118 self .assertIsInstance (response , dict )
112119
0 commit comments