@@ -14,15 +14,17 @@
 # limitations under the License.
 """
 
+import asyncio
 import os
 import threading
 import time
+from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from multiprocessing import current_process
 
 import uvicorn
 import zmq
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
 
@@ -48,6 +50,7 @@
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
     FlexibleArgumentParser,
+    StatefulSemaphore,
     api_server_logger,
     console_logger,
     is_port_available,
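`StatefulSemaphore` is imported from `fastdeploy.utils`, but its implementation is not part of this diff. As a rough sketch of the contract the call sites below rely on (an awaitable `acquire()`, plus `release()`, `locked()`, and a `status()` snapshot used in log messages), one minimal asyncio-based version might look like this; the real class may differ:

```python
# Illustrative sketch only; the real StatefulSemaphore lives in fastdeploy.utils.
# The interface is inferred from the call sites in this diff.
import asyncio


class StatefulSemaphore:
    def __init__(self, capacity: int):
        self._capacity = capacity
        self._semaphore = asyncio.Semaphore(capacity)
        self._acquired = 0

    async def acquire(self) -> None:
        # Awaitable, so callers can bound the wait with asyncio.wait_for().
        await self._semaphore.acquire()
        self._acquired += 1

    def release(self) -> None:
        self._acquired -= 1
        self._semaphore.release()

    def locked(self) -> bool:
        # True when every slot is taken.
        return self._semaphore.locked()

    def status(self) -> dict:
        # Snapshot that the server interpolates into log messages.
        return {"capacity": self._capacity, "acquired": self._acquired}
```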
@@ -60,6 +63,13 @@
 parser.add_argument("--workers", default=1, type=int, help="number of workers")
 parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
 parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
+parser.add_argument(
+    "--max-waiting-time",
+    default=-1,
+    type=int,
+    help="max waiting time for a connection; -1 means no waiting-time limit",
+)
+parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
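A quick way to sanity-check the semantics of the two new flags, using the standard `argparse` in place of FastDeploy's `FlexibleArgumentParser` (an assumption; the real parser may add extra behavior):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workers", default=1, type=int)
parser.add_argument("--max-waiting-time", default=-1, type=int)
parser.add_argument("--max-concurrency", default=512, type=int)

args = parser.parse_args(["--workers", "4", "--max-concurrency", "1024"])
assert args.max_waiting_time == -1   # default: no waiting-time limit
assert args.max_concurrency == 1024  # total cap shared across the 4 workers
```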
@@ -115,10 +125,11 @@ async def lifespan(app: FastAPI):
         args.reasoning_parser,
         args.data_parallel_size,
         args.enable_logprob,
+        args.workers,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
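How the serving handlers use `max_waiting_time` is outside this diff; a plausible pattern, consistent with the flag's `-1` convention, is to bound a semaphore wait with `asyncio.wait_for` (hypothetical helper, not the actual OpenAIServingChat code):

```python
import asyncio


async def acquire_with_deadline(semaphore: asyncio.Semaphore, max_waiting_time: int) -> None:
    """Hypothetical helper: wait for a slot, honoring the -1 = unlimited convention."""
    if max_waiting_time < 0:
        # Negative value: wait as long as it takes.
        await semaphore.acquire()
    else:
        # Non-negative value: give up once the deadline expires (raises TimeoutError).
        await asyncio.wait_for(semaphore.acquire(), timeout=max_waiting_time)
```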
@@ -140,6 +151,41 @@ async def lifespan(app: FastAPI):
 instrument(app)
 
 
+MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
+connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
+
+
+@asynccontextmanager
+async def connection_manager():
+    """
+    Async context manager that admits a request only while the connection semaphore has a free slot.
+    """
+    try:
+        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
+        yield
+    except asyncio.TimeoutError:
+        api_server_logger.info(f"Reached max concurrency, rejecting request: {connection_semaphore.status()}")
+        if connection_semaphore.locked():
+            connection_semaphore.release()
+        raise HTTPException(status_code=429, detail="Too many requests")
+
+
+def wrap_streaming_generator(original_generator: AsyncGenerator):
+    """
+    Wrap an async generator to release the connection semaphore when the stream finishes.
+    """
+
+    async def wrapped_generator():
+        try:
+            async for chunk in original_generator:
+                yield chunk
+        finally:
+            api_server_logger.debug(f"Releasing connection semaphore: {connection_semaphore.status()}")
+            connection_semaphore.release()
+
+    return wrapped_generator
+
+
 # TODO: pass the real engine status here; fetch it via pid
 @app.get("/health")
 def health(request: Request) -> Response:
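The per-worker cap uses ceiling division, so the slots across all uvicorn workers always add up to at least `--max-concurrency`. Note also that the 0.001 s timeout in `connection_manager` makes admission effectively fail-fast: a request is rejected with 429 as soon as no slot is free. A worked example with illustrative numbers:

```python
# Illustrative numbers: 512 total slots split across 3 uvicorn workers.
max_concurrency, workers = 512, 3
per_worker = (max_concurrency + workers - 1) // workers
assert per_worker == 171                        # ceil(512 / 3)
assert per_worker * workers >= max_concurrency  # never undershoots the total
```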
@@ -202,16 +248,23 @@ async def create_chat_completion(request: ChatCompletionRequest):
     status, msg = app.state.engine_client.is_workers_alive()
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
-    inject_to_metadata(request)
-    generator = await app.state.chat_handler.create_chat_completion(request)
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-
-    elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            inject_to_metadata(request)
+            generator = await app.state.chat_handler.create_chat_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            elif isinstance(generator, ChatCompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+
+    except HTTPException as e:
+        api_server_logger.error(f"Error in chat completion: {str(e)}")
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
 
 
 @app.post("/v1/completions")
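From a client's perspective, a saturated worker now answers with HTTP 429 instead of queueing indefinitely. A minimal probe of that behavior, assuming a server on localhost:8000 and an illustrative payload:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed host/port
    json={"model": "default", "messages": [{"role": "user", "content": "hi"}]},
)
if resp.status_code == 429:
    # Body shape matches the JSONResponse above: {"detail": "Too many requests"}
    print(resp.json()["detail"])
```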
@@ -224,13 +277,20 @@ async def create_completion(request: CompletionRequest):
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
 
-    generator = await app.state.completion_handler.create_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            generator = await app.state.completion_handler.create_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            elif isinstance(generator, CompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
 
 
 @app.get("/update_model_weight")