Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/unstract/llmwhisperer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def whisper(
ocr_provider: str = "advanced",
line_splitter_tolerance: float = 0.4,
horizontal_stretch_factor: float = 1.0,
encoding: str = "utf-8"
) -> dict:
"""
Sends a request to the LLMWhisperer API to process a document.
Expand All @@ -190,6 +191,7 @@ def whisper(
ocr_provider (str, optional): The OCR provider. Can be "advanced" or "basic". Defaults to "advanced".
line_splitter_tolerance (float, optional): The line splitter tolerance. Defaults to 0.4.
horizontal_stretch_factor (float, optional): The horizontal stretch factor. Defaults to 1.0.
encoding (str, optional): The character encoding assigned to the HTTP response (`response.encoding`) before its text is read. Defaults to "utf-8".

Returns:
dict: The response from the API as a dictionary.
Expand Down Expand Up @@ -268,6 +270,7 @@ def generate():
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout, stream=should_stream)
response.encoding = encoding
if response.status_code != 200 and response.status_code != 202:
message = json.loads(response.text)
message["status_code"] = response.status_code
Expand Down Expand Up @@ -318,7 +321,7 @@ def whisper_status(self, whisper_hash: str) -> dict:
message["status_code"] = response.status_code
return message

def whisper_retrieve(self, whisper_hash: str) -> dict:
def whisper_retrieve(self, whisper_hash: str, encoding: str = "utf-8") -> dict:
"""Retrieves the result of the whisper operation from the LLMWhisperer
API.

Expand All @@ -329,6 +332,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:

Args:
whisper_hash (str): The hash of the whisper operation.
encoding (str, optional): The character encoding assigned to the HTTP response (`response.encoding`) before its text is read. Defaults to "utf-8".

Returns:
dict: A dictionary containing the status code and the extracted text from the whisper operation.
Expand All @@ -345,6 +349,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout)
response.encoding = encoding
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down
Loading