diff --git a/src/unstract/llmwhisperer/client.py b/src/unstract/llmwhisperer/client.py index 39e34cb..e7076e9 100644 --- a/src/unstract/llmwhisperer/client.py +++ b/src/unstract/llmwhisperer/client.py @@ -169,6 +169,7 @@ def whisper( ocr_provider: str = "advanced", line_splitter_tolerance: float = 0.4, horizontal_stretch_factor: float = 1.0, + encoding: str = "utf-8" ) -> dict: """ Sends a request to the LLMWhisperer API to process a document. @@ -190,6 +191,7 @@ def whisper( ocr_provider (str, optional): The OCR provider. Can be "advanced" or "basic". Defaults to "advanced". line_splitter_tolerance (float, optional): The line splitter tolerance. Defaults to 0.4. horizontal_stretch_factor (float, optional): The horizontal stretch factor. Defaults to 1.0. + encoding (str): The character encoding to use for processing the text. Defaults to "utf-8". Returns: dict: The response from the API as a dictionary. @@ -268,6 +270,7 @@ def generate(): prepared = req.prepare() s = requests.Session() response = s.send(prepared, timeout=self.api_timeout, stream=should_stream) + response.encoding = encoding if response.status_code != 200 and response.status_code != 202: message = json.loads(response.text) message["status_code"] = response.status_code @@ -318,7 +321,7 @@ def whisper_status(self, whisper_hash: str) -> dict: message["status_code"] = response.status_code return message - def whisper_retrieve(self, whisper_hash: str) -> dict: + def whisper_retrieve(self, whisper_hash: str, encoding: str = "utf-8") -> dict: """Retrieves the result of the whisper operation from the LLMWhisperer API. @@ -329,6 +332,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict: Args: whisper_hash (str): The hash of the whisper operation. + encoding (str): The character encoding to use for processing the text. Defaults to "utf-8". Returns: dict: A dictionary containing the status code and the extracted text from the whisper operation. @@ -345,6 +349,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict: prepared = req.prepare() s = requests.Session() response = s.send(prepared, timeout=self.api_timeout) + response.encoding = encoding if response.status_code != 200: err = json.loads(response.text) err["status_code"] = response.status_code