Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/unstract/llmwhisperer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.21.0"
__version__ = "0.22.0"

from .client import LLMWhispererClient # noqa: F401

Expand Down
21 changes: 10 additions & 11 deletions src/unstract/llmwhisperer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ class LLMWhispererClient:
client's activities and errors.
"""

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_stream_handler = logging.StreamHandler()
log_stream_handler.setFormatter(formatter)
Expand Down Expand Up @@ -117,9 +115,7 @@ def __init__(
self.api_key = os.getenv("LLMWHISPERER_API_KEY", "")
else:
self.api_key = api_key
self.logger.debug(
"api_key set to %s", LLMWhispererUtils.redact_key(self.api_key)
)
self.logger.debug("api_key set to %s", LLMWhispererUtils.redact_key(self.api_key))

self.api_timeout = api_timeout

Expand Down Expand Up @@ -169,6 +165,7 @@ def whisper(
ocr_provider: str = "advanced",
line_splitter_tolerance: float = 0.4,
horizontal_stretch_factor: float = 1.0,
encoding: str = "utf-8",
) -> dict:
"""
Sends a request to the LLMWhisperer API to process a document.
Expand All @@ -190,6 +187,7 @@ def whisper(
ocr_provider (str, optional): The OCR provider. Can be "advanced" or "basic". Defaults to "advanced".
line_splitter_tolerance (float, optional): The line splitter tolerance. Defaults to 0.4.
horizontal_stretch_factor (float, optional): The horizontal stretch factor. Defaults to 1.0.
encoding (str, optional): The character encoding used to decode the API response text. Defaults to "utf-8".

Returns:
dict: The response from the API as a dictionary.
Expand Down Expand Up @@ -238,12 +236,10 @@ def whisper(
should_stream = False
if url == "":
if stream is not None:

should_stream = True

def generate():
for chunk in stream:
yield chunk
yield from stream

req = requests.Request(
"POST",
Expand All @@ -267,7 +263,8 @@ def generate():
req = requests.Request("POST", api_url, params=params, headers=self.headers)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout, stream=should_stream)
response = s.send(prepared, timeout=timeout, stream=should_stream)
response.encoding = encoding
if response.status_code != 200 and response.status_code != 202:
message = json.loads(response.text)
message["status_code"] = response.status_code
Expand Down Expand Up @@ -318,7 +315,7 @@ def whisper_status(self, whisper_hash: str) -> dict:
message["status_code"] = response.status_code
return message

def whisper_retrieve(self, whisper_hash: str) -> dict:
def whisper_retrieve(self, whisper_hash: str, encoding: str = "utf-8") -> dict:
"""Retrieves the result of the whisper operation from the LLMWhisperer
API.

Expand All @@ -329,6 +326,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:

Args:
whisper_hash (str): The hash of the whisper operation.
encoding (str, optional): The character encoding used to decode the API response text. Defaults to "utf-8".

Returns:
dict: A dictionary containing the status code and the extracted text from the whisper operation.
Expand All @@ -345,6 +343,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout)
response.encoding = encoding
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down
43 changes: 27 additions & 16 deletions src/unstract/llmwhisperer/client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class LLMWhispererClientV2:

api_key = ""
base_url = ""
api_timeout = 120

def __init__(
self,
Expand Down Expand Up @@ -139,7 +140,7 @@ def get_usage_info(self) -> dict:
req = requests.Request("GET", url, headers=self.headers)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120)
response = s.send(prepared, timeout=self.api_timeout)
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down Expand Up @@ -169,6 +170,7 @@ def whisper(
use_webhook="",
wait_for_completion=False,
wait_timeout=180,
encoding: str = "utf-8",
) -> dict:
"""
Sends a request to the LLMWhisperer API to process a document.
Expand All @@ -178,8 +180,10 @@ def whisper(
file_path (str, optional): The path to the file to be processed. Defaults to "".
stream (IO[bytes], optional): A stream of bytes to be processed. Defaults to None.
url (str, optional): The URL of the file to be processed. Defaults to "".
mode (str, optional): The processing mode. Can be "high_quality", "form", "low_cost" or "native_text". Defaults to "high_quality".
output_mode (str, optional): The output mode. Can be "layout_preserving" or "text". Defaults to "layout_preserving".
mode (str, optional): The processing mode. Can be "high_quality", "form", "low_cost" or "native_text".
Defaults to "high_quality".
output_mode (str, optional): The output mode. Can be "layout_preserving" or "text".
Defaults to "layout_preserving".
page_seperator (str, optional): The page separator. Defaults to "<<<".
pages_to_extract (str, optional): The pages to extract. Defaults to "".
median_filter_size (int, optional): The size of the median filter. Defaults to 0.
Expand All @@ -192,10 +196,15 @@ def whisper(
lang (str, optional): The language of the document. Defaults to "eng".
tag (str, optional): The tag for the document. Defaults to "default".
filename (str, optional): The name of the file to store in reports. Defaults to "".
webhook_metadata (str, optional): The webhook metadata. This data will be passed to the webhook if webhooks are used Defaults to "".
use_webhook (str, optional): Webhook name to call. Defaults to "". If not provided, the no webhook will be called.
wait_for_completion (bool, optional): Whether to wait for the whisper operation to complete. Defaults to False.
wait_timeout (int, optional): The number of seconds to wait for the whisper operation to complete. Defaults to 180.
webhook_metadata (str, optional): The webhook metadata. This data will be passed to the webhook if
webhooks are used. Defaults to "".
use_webhook (str, optional): Webhook name to call. Defaults to "". If not provided, then
no webhook will be called.
wait_for_completion (bool, optional): Whether to wait for the whisper operation to complete.
Defaults to False.
wait_timeout (int, optional): The number of seconds to wait for the whisper operation to complete.
Defaults to 180.
encoding (str, optional): The character encoding used to decode the API response text. Defaults to "utf-8".

Returns:
dict: The response from the API as a dictionary.
Expand Down Expand Up @@ -275,7 +284,8 @@ def generate():
req = requests.Request("POST", api_url, params=params, headers=self.headers)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120, stream=should_stream)
response = s.send(prepared, timeout=wait_timeout, stream=should_stream)
response.encoding = encoding
if response.status_code != 200 and response.status_code != 202:
message = json.loads(response.text)
message["status_code"] = response.status_code
Expand Down Expand Up @@ -371,7 +381,7 @@ def whisper_status(self, whisper_hash: str) -> dict:
req = requests.Request("GET", url, headers=self.headers, params=params)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120)
response = s.send(prepared, timeout=self.api_timeout)
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand All @@ -380,7 +390,7 @@ def whisper_status(self, whisper_hash: str) -> dict:
message["status_code"] = response.status_code
return message

def whisper_retrieve(self, whisper_hash: str) -> dict:
def whisper_retrieve(self, whisper_hash: str, encoding: str = "utf-8") -> dict:
"""Retrieves the result of the whisper operation from the LLMWhisperer
API.

Expand All @@ -391,6 +401,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:

Args:
whisper_hash (str): The hash of the whisper operation.
encoding (str, optional): The character encoding used to decode the API response text. Defaults to "utf-8".

Returns:
dict: A dictionary containing the status code and the extracted text from the whisper operation.
Expand All @@ -406,7 +417,8 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:
req = requests.Request("GET", url, headers=self.headers, params=params)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120)
response = s.send(prepared, timeout=self.api_timeout)
response.encoding = encoding
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down Expand Up @@ -449,7 +461,7 @@ def register_webhook(self, url: str, auth_token: str, webhook_name: str) -> dict
req = requests.Request("POST", url, headers=headersx, json=data)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120)
response = s.send(prepared, timeout=self.api_timeout)
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down Expand Up @@ -480,7 +492,7 @@ def get_webhook_details(self, webhook_name: str) -> dict:
req = requests.Request("GET", url, headers=self.headers, params=params)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=120)
response = s.send(prepared, timeout=self.api_timeout)
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand All @@ -493,9 +505,8 @@ def get_highlight_rect(
target_width: int,
target_height: int,
) -> tuple[int, int, int, int, int]:
"""
Given the line metadata and the line number, this function returns the bounding box of the line
in the format (page,x1,y1,x2,y2)
"""Given the line metadata and the line number, this function returns
the bounding box of the line in the format (page,x1,y1,x2,y2).

Args:
line_metadata (list[int]): The line metadata returned by the LLMWhisperer API.
Expand Down
1 change: 1 addition & 0 deletions tests/integration/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_get_usage_info(client):
("ocr", "text", "restaurant_invoice_photo.pdf"),
("text", "line-printer", "restaurant_invoice_photo.pdf"),
("text", "text", "handwritten-form.pdf"),
("ocr", "line-printer", "utf_8_chars.pdf"),
],
)
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
Expand Down
1 change: 1 addition & 0 deletions tests/integration/client_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_get_usage_info(client_v2):
("text", "low_cost", "credit_card.pdf"),
("text", "high_quality", "restaurant_invoice_photo.pdf"),
("text", "form", "handwritten-form.pdf"),
("layout_preserving", "high_quality", "utf_8_chars.pdf"),
],
)
def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
Expand Down
Loading