1616from queue import Empty , Queue
1717from typing import Dict , List , Optional
1818
19- import zmq
2019from fastapi import HTTPException
2120
2221from litserve import LitAPI
2827logger = logging .getLogger (__name__ )
2928
3029
def run_batched_loop(
    lit_api: LitAPI,
    lit_spec: LitSpec,
    request_queue: Queue,
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
    callback_runner: CallbackRunner,
    socket: Optional[zmq.Socket],
):
    """Worker loop that batches queued requests through the LitAPI pipeline.

    Repeatedly collates up to ``max_batch_size`` requests (waiting at most
    ``batch_timeout`` seconds), runs decode -> batch -> predict -> unbatch ->
    encode, and dispatches each result either over ``socket`` (ZMQ transport,
    when provided) or onto the per-worker ``response_queues``.

    Args:
        lit_api: User-provided API implementation whose hooks are invoked.
        lit_spec: Optional spec; if it defines ``populate_context``, it is
            called once per request to fill that request's context dict.
        request_queue: Source of incoming ``(response_queue_id, uid, ...)`` items.
        response_queues: One queue per server worker; indexed by response_queue_id.
        max_batch_size: Upper bound on requests collated into one batch.
        batch_timeout: Seconds to wait while filling a batch.
        callback_runner: Fires BEFORE/AFTER events around each pipeline stage.
        socket: Optional ZMQ socket; when set, responses are sent with
            ``send_pyobj`` instead of the queues.

    Note:
        This function never returns; it is intended to run as a worker's
        main loop. Errors are reported back per-request, not raised.
    """
    while True:
        batches, timed_out_uids = collate_requests(
            lit_api,
            request_queue,
            max_batch_size,
            batch_timeout,
        )

        # Requests that exceeded lit_api.request_timeout while queued are
        # failed fast with a 504 before any model work happens.
        for response_queue_id, uid in timed_out_uids:
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            if socket:
                socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
            else:
                response_queues[response_queue_id].put((
                    uid,
                    (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
                ))

        if not batches:
            continue
        logger.debug(f"{len(batches)} batched requests received")
        response_queue_ids, uids, inputs = zip(*batches)
        num_inputs = len(inputs)
        try:
            # BUGFIX: previously `[{}] * num_inputs`, which creates N references
            # to a SINGLE shared dict — populate_context/_inject_context would
            # leak state between requests in the same batch. Each request must
            # get its own context dict.
            contexts = [{} for _ in range(num_inputs)]
            if hasattr(lit_spec, "populate_context"):
                for request, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, request)

            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
            x = [
                _inject_context(
                    context,
                    lit_api.decode_request,
                    request,
                )
                for request, context in zip(inputs, contexts)
            ]
            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

            x = lit_api.batch(x)

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y = _inject_context(contexts, lit_api.predict, x)
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            outputs = lit_api.unbatch(y)

            # A mismatch here means the user's predict/unbatch dropped or
            # duplicated items; fail the whole batch explicitly.
            if len(outputs) != num_inputs:
                # BUGFIX: this message was missing the f-prefix, so the
                # {len(outputs)}/{num_inputs} placeholders were logged verbatim.
                logger.error(
                    f"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
                    "Please check the predict/unbatch method of the LitAPI implementation."
                )
                raise HTTPException(500, "Batch size mismatch")

            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
            y_enc_list = []
            for response_queue_id, output, uid, context in zip(response_queue_ids, outputs, uids, contexts):
                y_enc = _inject_context(context, lit_api.encode_response, output)
                y_enc_list.append((response_queue_id, uid, y_enc))
            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

            # Dispatch only after the whole batch encoded successfully, so a
            # failure mid-encode reports an error for every request uniformly.
            for response_queue_id, uid, y_enc in y_enc_list:
                response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

        except HTTPException as e:
            # Intentional HTTP errors are made pickleable so they survive the
            # queue/socket hop and keep their status code.
            for response_queue_id, uid in zip(response_queue_ids, uids):
                if socket:
                    socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
                else:
                    response_queues[response_queue_id].put((
                        uid,
                        (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
                    ))

        except Exception as e:
            # Unexpected user-code failure: log the traceback once, then report
            # the same error to every request in the batch.
            logger.exception(
                "LitAPI ran into an error while processing the batched request.\n"
                "Please check the error trace for more details."
            )
            for response_queue_id, uid in zip(response_queue_ids, uids):
                if socket:
                    socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
                else:
                    response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
130-
131-
13230class SingleLoop (DefaultLoop ):
13331 def run_single_loop (
13432 self ,
13533 lit_api : LitAPI ,
136- lit_spec : LitSpec ,
34+ lit_spec : Optional [ LitSpec ] ,
13735 request_queue : Queue ,
13836 response_queues : List [Queue ],
13937 callback_runner : CallbackRunner ,
140- socket : Optional [zmq .Socket ],
14138 ):
14239 while True :
14340 try :
@@ -233,9 +130,8 @@ def __call__(
233130 stream : bool ,
234131 workers_setup_status : Dict [int , str ],
235132 callback_runner : CallbackRunner ,
236- socket : Optional [zmq .Socket ],
237133 ):
238- self .run_single_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner , socket )
134+ self .run_single_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner )
239135
240136
241137class BatchedLoop (DefaultLoop ):
@@ -248,7 +144,6 @@ def run_batched_loop(
248144 max_batch_size : int ,
249145 batch_timeout : float ,
250146 callback_runner : CallbackRunner ,
251- socket : Optional [zmq .Socket ],
252147 ):
253148 while True :
254149 batches , timed_out_uids = collate_requests (
@@ -264,13 +159,9 @@ def run_batched_loop(
264159 "has been timed out. "
265160 "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
266161 )
267- if socket :
268- socket .send_pyobj ((uid , (HTTPException (504 , "Request timed out" ), LitAPIStatus .ERROR )))
269- else :
270- response_queues [response_queue_id ].put ((
271- uid ,
272- (HTTPException (504 , "Request timed out" ), LitAPIStatus .ERROR ),
273- ))
162+ self .put_response (
163+ response_queues , response_queue_id , uid , HTTPException (504 , "Request timed out" ), LitAPIStatus .ERROR
164+ )
274165
275166 if not batches :
276167 continue
@@ -317,28 +208,25 @@ def run_batched_loop(
317208 callback_runner .trigger_event (EventTypes .AFTER_ENCODE_RESPONSE , lit_api = lit_api )
318209
319210 for response_queue_id , uid , y_enc in y_enc_list :
320- response_queues [ response_queue_id ]. put (( uid , ( y_enc , LitAPIStatus .OK )) )
211+ self . put_response ( response_queues , response_queue_id , uid , y_enc , LitAPIStatus .OK )
321212
322213 except HTTPException as e :
323214 for response_queue_id , uid in zip (response_queue_ids , uids ):
324- if socket :
325- socket . send_pyobj (( uid , ( PickleableHTTPException . from_exception ( e ), LitAPIStatus . ERROR )))
326- else :
327- response_queues [ response_queue_id ]. put ((
328- uid ,
329- ( PickleableHTTPException . from_exception ( e ), LitAPIStatus .ERROR ) ,
330- ) )
215+ self . put_response (
216+ response_queues ,
217+ response_queue_id ,
218+ uid ,
219+ PickleableHTTPException . from_exception ( e ) ,
220+ LitAPIStatus .ERROR ,
221+ )
331222
332223 except Exception as e :
333224 logger .exception (
334225 "LitAPI ran into an error while processing the batched request.\n "
335226 "Please check the error trace for more details."
336227 )
337228 for response_queue_id , uid in zip (response_queue_ids , uids ):
338- if socket :
339- socket .send_pyobj ((uid , (e , LitAPIStatus .ERROR )))
340- else :
341- response_queues [response_queue_id ].put ((uid , (e , LitAPIStatus .ERROR )))
229+ self .put_response (response_queues , response_queue_id , uid , e , LitAPIStatus .ERROR )
342230
343231 def __call__ (
344232 self ,
@@ -353,7 +241,6 @@ def __call__(
353241 stream : bool ,
354242 workers_setup_status : Dict [int , str ],
355243 callback_runner : CallbackRunner ,
356- socket : Optional [zmq .Socket ],
357244 ):
358245 self .run_batched_loop (
359246 lit_api ,
@@ -363,5 +250,4 @@ def __call__(
363250 max_batch_size ,
364251 batch_timeout ,
365252 callback_runner ,
366- socket ,
367253 )
0 commit comments