1919import json
2020import logging
2121import os
22+ from typing import IO
2223
2324import requests
2425
@@ -57,7 +58,9 @@ class LLMWhispererClient:
5758 client's activities and errors.
5859 """
5960
60- formatter = logging .Formatter ("%(asctime)s - %(name)s - %(levelname)s - %(message)s" )
61+ formatter = logging .Formatter (
62+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
63+ )
6164 logger = logging .getLogger (__name__ )
6265 log_stream_handler = logging .StreamHandler ()
6366 log_stream_handler .setFormatter (formatter )
@@ -114,7 +117,9 @@ def __init__(
114117 self .api_key = os .getenv ("LLMWHISPERER_API_KEY" , "" )
115118 else :
116119 self .api_key = api_key
117- self .logger .debug ("api_key set to %s" , LLMWhispererUtils .redact_key (self .api_key ))
120+ self .logger .debug (
121+ "api_key set to %s" , LLMWhispererUtils .redact_key (self .api_key )
122+ )
118123
119124 self .api_timeout = api_timeout
120125
@@ -150,6 +155,7 @@ def get_usage_info(self) -> dict:
150155 def whisper (
151156 self ,
152157 file_path : str = "" ,
158+ stream : IO [bytes ] = None ,
153159 url : str = "" ,
154160 processing_mode : str = "ocr" ,
155161 output_mode : str = "line-printer" ,
@@ -170,6 +176,7 @@ def whisper(
170176
171177 Args:
172178 file_path (str, optional): The path to the file to be processed. Defaults to "".
179+ stream (IO[bytes], optional): A stream of bytes to be processed. Defaults to None.
173180 url (str, optional): The URL of the file to be processed. Defaults to "".
174181 processing_mode (str, optional): The processing mode. Can be "ocr" or "text". Defaults to "ocr".
175182 output_mode (str, optional): The output mode. Can be "line-printer" or "text". Defaults to "line-printer".
@@ -212,11 +219,11 @@ def whisper(
212219 self .logger .debug ("api_url: %s" , api_url )
213220 self .logger .debug ("params: %s" , params )
214221
215- if url == "" and file_path == "" :
222+ if url == "" and file_path == "" and stream is None :
216223 raise LLMWhispererClientException (
217224 {
218225 "status_code" : - 1 ,
219- "message" : "Either url or file_path must be provided" ,
226+ "message" : "Either url, stream or file_path must be provided" ,
220227 }
221228 )
222229
@@ -228,21 +235,39 @@ def whisper(
228235 }
229236 )
230237
238+ should_stream = False
231239 if url == "" :
232- with open (file_path , "rb" ) as f :
233- data = f .read ()
234- req = requests .Request (
235- "POST" ,
236- api_url ,
237- params = params ,
238- headers = self .headers ,
239- data = data ,
240- )
240+ if stream is not None :
241+
242+ should_stream = True
243+
244+ def generate ():
245+ for chunk in stream :
246+ yield chunk
247+
248+ req = requests .Request (
249+ "POST" ,
250+ api_url ,
251+ params = params ,
252+ headers = self .headers ,
253+ data = generate (),
254+ )
255+
256+ else :
257+ with open (file_path , "rb" ) as f :
258+ data = f .read ()
259+ req = requests .Request (
260+ "POST" ,
261+ api_url ,
262+ params = params ,
263+ headers = self .headers ,
264+ data = data ,
265+ )
241266 else :
242267 req = requests .Request ("POST" , api_url , params = params , headers = self .headers )
243268 prepared = req .prepare ()
244269 s = requests .Session ()
245- response = s .send (prepared , timeout = self .api_timeout )
270+ response = s .send (prepared , timeout = self .api_timeout , stream = should_stream )
246271 if response .status_code != 200 and response .status_code != 202 :
247272 message = json .loads (response .text )
248273 message ["status_code" ] = response .status_code
0 commit comments