Commit 6706ccb

[BugFix] fix too many open files problem (#3275)

1 parent 1b6f482 commit 6706ccb

File tree

6 files changed (+178, -23 lines)

fastdeploy/entrypoints/engine_client.py

Lines changed: 3 additions & 1 deletion

@@ -24,7 +24,7 @@
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, api_server_logger
+from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger


 class EngineClient:
@@ -44,6 +44,7 @@ def __init__(
         reasoning_parser=None,
         data_parallel_size=1,
         enable_logprob=False,
+        workers=1,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -76,6 +77,7 @@ def __init__(
             suffix=pid,
             create=False,
         )
+        self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)

     def create_zmq_client(self, model, mode):
         """

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 80 additions & 20 deletions

@@ -14,15 +14,17 @@
 # limitations under the License.
 """

+import asyncio
 import os
 import threading
 import time
+from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from multiprocessing import current_process

 import uvicorn
 import zmq
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST

@@ -48,6 +50,7 @@
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
     FlexibleArgumentParser,
+    StatefulSemaphore,
     api_server_logger,
     console_logger,
     is_port_available,
@@ -60,6 +63,13 @@
 parser.add_argument("--workers", default=1, type=int, help="number of workers")
 parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
 parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
+parser.add_argument(
+    "--max-waiting-time",
+    default=-1,
+    type=int,
+    help="max waiting time for connection, if set value -1 means no waiting time limit",
+)
+parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
@@ -115,10 +125,11 @@ async def lifespan(app: FastAPI):
         args.reasoning_parser,
         args.data_parallel_size,
         args.enable_logprob,
+        args.workers,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
@@ -140,6 +151,41 @@ async def lifespan(app: FastAPI):
 instrument(app)


+MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
+connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
+
+
+@asynccontextmanager
+async def connection_manager():
+    """
+    async context manager for connection manager
+    """
+    try:
+        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
+        yield
+    except asyncio.TimeoutError:
+        api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
+        if connection_semaphore.locked():
+            connection_semaphore.release()
+        raise HTTPException(status_code=429, detail="Too many requests")
+
+
+def wrap_streaming_generator(original_generator: AsyncGenerator):
+    """
+    Wrap an async generator to release the connection semaphore when the generator is finished.
+    """
+
+    async def wrapped_generator():
+        try:
+            async for chunk in original_generator:
+                yield chunk
+        finally:
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
+            connection_semaphore.release()
+
+    return wrapped_generator
+
+
 # TODO: pass the real engine value; get the status via pid
 @app.get("/health")
 def health(request: Request) -> Response:
@@ -202,16 +248,23 @@ async def create_chat_completion(request: ChatCompletionRequest):
     status, msg = app.state.engine_client.is_workers_alive()
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
-    inject_to_metadata(request)
-    generator = await app.state.chat_handler.create_chat_completion(request)
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-
-    elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            inject_to_metadata(request)
+            generator = await app.state.chat_handler.create_chat_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            elif isinstance(generator, ChatCompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+
+    except HTTPException as e:
+        api_server_logger.error(f"Error in chat completion: {str(e)}")
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.post("/v1/completions")
@@ -224,13 +277,20 @@ async def create_completion(request: CompletionRequest):
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)

-    generator = await app.state.completion_handler.create_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            generator = await app.state.completion_handler.create_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            elif isinstance(generator, CompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.get("/update_model_weight")

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 15 additions & 1 deletion

@@ -49,10 +49,11 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
+        self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
@@ -94,6 +95,15 @@ async def create_chat_completion(self, request: ChatCompletionRequest):

         del current_req_dict

+        try:
+            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
+
         if request.stream:
             return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
         else:
@@ -310,6 +320,8 @@ async def chat_completion_stream_generator(
                 yield f"data: {error_data}\n\n"
             finally:
                 dealer.close()
+                self.engine_client.semaphore.release()
+                api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
             yield "data: [DONE]\n\n"

     async def chat_completion_full_generator(
@@ -384,6 +396,8 @@ async def chat_completion_full_generator(
                     break
             finally:
                 dealer.close()
+                self.engine_client.semaphore.release()
+                api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

         choices = []
         output = final_res["outputs"]
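Both serving handlers (chat here, completions below) now apply the same admission policy before dispatching to the engine: wait on the engine-side semaphore indefinitely when max_waiting_time is negative, otherwise give up after max_waiting_time seconds and return an ErrorResponse with code 408. A standalone sketch of that policy, using a plain asyncio.Semaphore and an invented helper name acquire_slot:

```python
import asyncio


async def acquire_slot(sem: asyncio.Semaphore, max_waiting_time: float) -> bool:
    """Return True once a slot is acquired, False if the wait timed out."""
    try:
        if max_waiting_time < 0:
            await sem.acquire()  # no limit: queue until a slot frees up
        else:
            await asyncio.wait_for(sem.acquire(), timeout=max_waiting_time)
        return True
    except asyncio.TimeoutError:
        return False  # the handlers map this to a 408-style ErrorResponse


async def demo():
    sem = asyncio.Semaphore(1)
    await sem.acquire()                   # exhaust the only slot
    print(await acquire_slot(sem, 0.05))  # False after ~50 ms


asyncio.run(demo())
```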

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 12 additions & 1 deletion

@@ -40,11 +40,12 @@


 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
+        self.max_waiting_time = max_waiting_time
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -114,6 +115,14 @@ async def create_completion(self, request: CompletionRequest):

         del current_req_dict

+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
+
         if request.stream:
             return self.completion_stream_generator(
                 request=request,
@@ -223,6 +232,7 @@ async def completion_full_generator(
         finally:
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()

     async def completion_stream_generator(
         self,
@@ -372,6 +382,7 @@ async def completion_stream_generator(
             del request
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()
         yield "data: [DONE]\n\n"

     def request_output_to_completion_response(

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions

@@ -82,6 +82,8 @@
     "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
     # set trace attribute job_id.
     "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
+    # support max connections
+    "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
 }
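One detail worth noting: unlike the neighboring entries, the new FD_SUPPORT_MAX_CONNECTIONS lambda returns a hard-coded 768 rather than reading os.getenv, so despite living in the environment-variable table, the engine-side connection budget cannot be overridden from the environment.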

fastdeploy/utils.py

Lines changed: 66 additions & 0 deletions

@@ -15,6 +15,7 @@
 """

 import argparse
+import asyncio
 import codecs
 import importlib
 import logging
@@ -291,6 +292,16 @@ def extract_tar(tar_path, output_dir):
         raise RuntimeError(f"Extraction failed: {e!s}")


+def get_limited_max_value(max_value):
+    def validator(value):
+        value = float(value)
+        if value > max_value:
+            raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}")
+        return value
+
+    return validator
+
+
 def download_model(url, output_dir, temp_tar):
     """
     Download the model and extract it to the specified directory.
@@ -596,6 +607,61 @@ def version():
     return content


+class StatefulSemaphore:
+    __slots__ = ("_semaphore", "_max_value", "_acquired_count", "_last_reset")
+
+    """
+    StatefulSemaphore is a class that wraps an asyncio.Semaphore and provides additional stateful information.
+    """
+
+    def __init__(self, value: int):
+        """
+        StatefulSemaphore constructor
+        """
+        if value < 0:
+            raise ValueError("Value must be non-negative.")
+        self._semaphore = asyncio.Semaphore(value)
+        self._max_value = value
+        self._acquired_count = 0
+        self._last_reset = time.monotonic()
+
+    async def acquire(self):
+        await self._semaphore.acquire()
+        self._acquired_count += 1
+
+    def release(self):
+        self._semaphore.release()
+        self._acquired_count = max(0, self._acquired_count - 1)
+
+    def locked(self) -> bool:
+        return self._semaphore.locked()
+
+    @property
+    def available(self) -> int:
+        return self._max_value - self._acquired_count
+
+    @property
+    def acquired(self) -> int:
+        return self._acquired_count
+
+    @property
+    def max_value(self) -> int:
+        return self._max_value
+
+    @property
+    def uptime(self) -> float:
+        return time.monotonic() - self._last_reset
+
+    def status(self) -> dict:
+        return {
+            "available": self.available,
+            "acquired": self.acquired,
+            "max_value": self.max_value,
+            "uptime": round(self.uptime, 2),
+        }
+
+
 llm_logger = get_logger("fastdeploy", "fastdeploy.log")
 data_processor_logger = get_logger("data_processor", "data_processor.log")
 scheduler_logger = get_logger("scheduler", "scheduler.log")
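A quick usage sketch of the new class, assuming it is imported from fastdeploy.utils as this commit defines it (the printed values are illustrative):

```python
import asyncio

from fastdeploy.utils import StatefulSemaphore


async def main():
    sem = StatefulSemaphore(2)
    await sem.acquire()
    print(sem.status())   # e.g. {'available': 1, 'acquired': 1, 'max_value': 2, 'uptime': 0.0}
    sem.release()
    print(sem.available)  # 2


asyncio.run(main())
```

The wrapper exists because asyncio.Semaphore exposes no counters of its own; the tracked acquired/available numbers are what the API server and serving handlers log on every acquire and release.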
