@@ -18,6 +18,8 @@

 # TODO: Have a generic "Responder" interface vs just the concrete impl.
 from shortfin.interop.fastapi import FastAPIResponder
+from fastapi.responses import JSONResponse
+from fastapi import status

 from .config_struct import DecodeConfig
 from .io_struct import (
@@ -130,6 +132,7 @@ class ClientGenerateBatchProcess(sf.Process):
         "responder",
         "tokenizer",
         "decode_config",
+        "service",
     ]

     def __init__(
@@ -140,6 +143,7 @@ def __init__(
         fiber: sf.Fiber | None = None,
     ):
         super().__init__(fiber=service.main_fiber if fiber is None else fiber)
+        self.service = service
         self.gen_req = gen_req
         self.responder = responder
         self.tokenizer = service.tokenizer
@@ -151,12 +155,29 @@ def __init__(

     async def run(self):
         logger.debug("Started ClientBatchGenerateProcess: %r", self)
-        streaming = self.gen_req.stream
-        self.responder.start_response()
-        if streaming:
-            self.responder.stream_start()
+
+        # Try to add request to queue
+        # TODO(@zphoenixrises): Add load testing and integration tests for this.
+        if not self.service.add_to_queue():
+            error_response = JSONResponse(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                content={
+                    "error": "Server queue is full. Please try again later.",
+                    "code": "QUEUE_FULL",
+                    "current_size": self.service.current_queue_size,
+                    "max_size": self.service.max_queue_size,
+                },
+            )
+            self.responder.send_response(error_response)
+            self.responder.ensure_response()
+            return

         try:
+            streaming = self.gen_req.stream
+            self.responder.start_response()
+            if streaming:
+                self.responder.stream_start()
+
             # Launch all individual generate processes and wait for them to finish.
             gen_processes = []
             input_ids = self.gen_req.input_ids
@@ -166,6 +187,7 @@ async def run(self):
                 input_batch = [input_ids] if self.gen_req.is_single else input_ids
             else:
                 input_batch = self.tokenize()
+
             for index, input_tokens in enumerate(input_batch):
                 decode_config = copy(self.decode_config)
                 decode_config.update_from_sampling_params(
@@ -189,7 +211,10 @@ async def run(self):

             await asyncio.gather(*gen_processes)
             self.generate_response(gen_processes, streaming)
+
         finally:
+            # Remove request from queue when done
+            self.service.remove_from_queue()
             self.responder.ensure_response()

     def generate_response(
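
The queue accounting this change relies on (service.add_to_queue(), service.remove_from_queue(), and the current_queue_size/max_queue_size attributes) is not shown in this hunk. Below is a minimal sketch of what those helpers could look like, assuming a simple lock-guarded counter; the actual shortfin service implementation may differ, and the class name and default limit here are placeholders.

import threading


class GenerateService:
    """Hypothetical stand-in for the service object used in the diff above."""

    def __init__(self, max_queue_size: int = 128):
        self.max_queue_size = max_queue_size
        self.current_queue_size = 0
        # Lock guards the counter; a real event-loop-driven service might
        # rely on single-threaded scheduling instead.
        self._queue_lock = threading.Lock()

    def add_to_queue(self) -> bool:
        # Reserve a slot; refuse admission once the server is saturated.
        with self._queue_lock:
            if self.current_queue_size >= self.max_queue_size:
                return False
            self.current_queue_size += 1
            return True

    def remove_from_queue(self) -> None:
        # Release a slot; called from the finally block so it always runs
        # for admitted requests, even when generation raises.
        with self._queue_lock:
            if self.current_queue_size > 0:
                self.current_queue_size -= 1

Note the ordering in run(): the 503 rejection returns before the try/finally is entered, so remove_from_queue() only executes for requests that actually acquired a slot, while the finally block guarantees the slot is released even if tokenization or generation raises.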