diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 9ff35d47b9..bc0ec7a022 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -14,17 +14,16 @@
 # limitations under the License.
 """

-import zmq
 import time
-from random import randint
 import uuid
+
 import numpy as np

+from fastdeploy import envs
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.engine.request import Request
-from fastdeploy.inter_communicator import ZmqClient, IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.utils import api_server_logger, EngineError
+from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger


 class EngineClient:
@@ -33,7 +32,7 @@ class EngineClient:
     """

     def __init__(self, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm_per_prompt, mm_processor_kwargs,
-                 enable_mm=False, reasoning_parser=None):
+                 enable_mm=False, reasoning_parser=None, workers=1):
         input_processor = InputPreprocessor(tokenizer,
                                             reasoning_parser,
                                             limit_mm_per_prompt,
@@ -57,6 +56,7 @@ def __init__(self, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm
                                    dtype=np.int32,
                                    suffix=pid,
                                    create=False)
+        self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)

     def create_zmq_client(self, model, mode):
         """
@@ -75,7 +75,6 @@ def format_and_add_data(self, prompts: dict):
         if "request_id" not in prompts:
             request_id = str(uuid.uuid4())
             prompts["request_id"] = request_id
-        query_list = []

         if "max_tokens" not in prompts:
             prompts["max_tokens"] = self.max_model_len - 1
@@ -178,7 +177,7 @@ def vaild_parameters(self, data):

         if data.get("temperature"):
             if data["temperature"] < 0:
-                raise ValueError(f"temperature must be non-negative")
+                raise ValueError("temperature must be non-negative")

         if data.get("presence_penalty"):
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index dd82470fc5..384b288f02 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -13,18 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
""" +import asyncio import os import threading import time +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from multiprocessing import current_process import uvicorn import zmq -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import CONTENT_TYPE_LATEST -from fastdeploy.metrics.trace_util import inject_to_metadata,instrument from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.engine import LLMEngine @@ -33,8 +34,8 @@ ChatCompletionResponse, CompletionRequest, CompletionResponse, - ErrorResponse, - ControlSchedulerRequest) + ControlSchedulerRequest, + ErrorResponse) from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat from fastdeploy.entrypoints.openai.serving_completion import \ OpenAIServingCompletion @@ -42,10 +43,10 @@ cleanup_prometheus_files, get_filtered_metrics, main_process_metrics) -from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger, - console_logger, is_port_available, - retrive_model_from_server) - +from fastdeploy.metrics.trace_util import inject_to_metadata, instrument +from fastdeploy.utils import (FlexibleArgumentParser, StatefulSemaphore, + api_server_logger, console_logger, + is_port_available, retrive_model_from_server) parser = FlexibleArgumentParser() parser.add_argument("--port", @@ -65,6 +66,13 @@ default=-1, type=int, help="port for controller server") +parser.add_argument( + "--max-waiting-time", + default=-1, + type=int, + help="max waiting time for connection, if set value -1 means no waiting time limit", +) +parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() args.model = retrive_model_from_server(args.model) @@ -121,10 +129,11 @@ async def lifespan(app: FastAPI): args.tensor_parallel_size, pid, args.limit_mm_per_prompt, args.mm_processor_kwargs, args.enable_mm, - args.reasoning_parser) + args.reasoning_parser, + workers=args.workers) app.state.dynamic_load_weight = args.dynamic_load_weight - chat_handler = OpenAIServingChat(engine_client, pid) - completion_handler = OpenAIServingCompletion(engine_client, pid) + chat_handler = OpenAIServingChat(engine_client, pid, args.max_waiting_time) + completion_handler = OpenAIServingCompletion(engine_client, pid, args.max_waiting_time) engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) engine_client.pid = pid app.state.engine_client = engine_client @@ -145,6 +154,40 @@ async def lifespan(app: FastAPI): instrument(app) +MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers +connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) + + +@asynccontextmanager +async def connection_manager(): + """ + async context manager for connection manager + """ + try: + await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001) + yield + except asyncio.TimeoutError: + api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}") + if connection_semaphore.locked(): + connection_semaphore.release() + raise HTTPException(status_code=429, detail="Too many requests") + + +def wrap_streaming_generator(original_generator: AsyncGenerator): + """ + Wrap an async generator to release the connection semaphore when the generator is finished. 
+    """
+
+    async def wrapped_generator():
+        try:
+            async for chunk in original_generator:
+                yield chunk
+        finally:
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
+            connection_semaphore.release()
+
+    return wrapped_generator
+

 # TODO: pass the real engine state; look up its status via pid
 @app.get("/health")
 def health(request: Request) -> Response:
@@ -213,17 +256,24 @@ async def create_chat_completion(request: ChatCompletionRequest):
             return JSONResponse(
                 content={"error": "Worker Service Not Healthy"},
                 status_code=304)
-    inject_to_metadata(request)
-    generator = await app.state.chat_handler.create_chat_completion(request)
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+    try:
+        async with connection_manager():
+            inject_to_metadata(request)
+            generator = await app.state.chat_handler.create_chat_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            elif isinstance(generator, ChatCompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")

-    elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
+    except HTTPException as e:
+        api_server_logger.error(f"Error in chat completion: {str(e)}")
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})

-    return StreamingResponse(content=generator, media_type="text/event-stream")


 @app.post("/v1/completions")
@@ -238,15 +288,20 @@ async def create_completion(request: CompletionRequest):
                 content={"error": "Worker Service Not Healthy"},
                 status_code=304)

-    generator = await app.state.completion_handler.create_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
-
+    try:
+        async with connection_manager():
+            generator = await app.state.completion_handler.create_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            elif isinstance(generator, CompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})

 @app.get("/update_model_weight")
 def update_model_weight(request: Request) -> Response:
@@ -362,7 +417,7 @@ def control_scheduler(request: ControlSchedulerRequest):
     Control the scheduler behavior with the given parameters.
""" content = ErrorResponse(object="", message="Scheduler updated successfully", code=0) - + global llm_engine if llm_engine is None: content.message = "Engine is not loaded" diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 31359e7285..87e90e3845 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -39,8 +39,9 @@ class OpenAIServingChat: OpenAI-style chat completions serving """ - def __init__(self, engine_client, pid): + def __init__(self, engine_client, pid, max_waiting_time): self.engine_client = engine_client + self.max_waiting_time = max_waiting_time self.pid = pid async def create_chat_completion( @@ -65,6 +66,15 @@ async def create_chat_completion( del current_req_dict + try: + api_server_logger.debug(f"{self.engine_client.semaphore.status()}") + if self.max_waiting_time < 0: + await self.engine_client.semaphore.acquire() + else: + await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time) + except Exception: + return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}") + if request.stream: return self.chat_completion_stream_generator( request, request_id, @@ -269,6 +279,8 @@ async def chat_completion_stream_generator( yield f"data: {error_data}\n\n" finally: dealer.close() + self.engine_client.semaphore.release() + api_server_logger.info(f"release {self.engine_client.semaphore.status()}") yield "data: [DONE]\n\n" async def chat_completion_full_generator( @@ -341,6 +353,8 @@ async def chat_completion_full_generator( break finally: dealer.close() + self.engine_client.semaphore.release() + api_server_logger.info(f"release {self.engine_client.semaphore.status()}") choices = [] output = final_res["outputs"] diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index c69824400d..975c78f426 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -15,37 +15,26 @@ """ import asyncio -import aiozmq import json -from aiozmq import zmq -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task import time -from collections.abc import AsyncGenerator, AsyncIterator -from collections.abc import Sequence as GenericSequence -from typing import Optional, Union, cast, TypeVar, List import uuid -from fastapi import Request +from typing import List + +import aiozmq +from aiozmq import zmq +from fastdeploy.engine.request import RequestOutput from fastdeploy.entrypoints.openai.protocol import ( - ErrorResponse, - CompletionRequest, - CompletionResponse, - CompletionStreamResponse, - CompletionResponseStreamChoice, - CompletionResponseChoice, - UsageInfo, - DeltaToolCall, - DeltaFunctionCall, - ToolCall, - FunctionCall -) + CompletionRequest, CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, + UsageInfo) from fastdeploy.utils import api_server_logger -from fastdeploy.engine.request import RequestOutput class OpenAIServingCompletion: - def __init__(self, engine_client, pid): + def __init__(self, engine_client, pid, max_waiting_time): self.engine_client = engine_client + self.max_waiting_time = max_waiting_time self.pid = pid async def create_completion(self, request: CompletionRequest): @@ -98,6 +87,13 @@ async def create_completion(self, request: CompletionRequest): return ErrorResponse(message=str(e), 
                                  code=400)
         del current_req_dict

+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")

         if request.stream:
             return self.completion_stream_generator(
@@ -195,6 +191,7 @@ async def completion_full_generator(
         finally:
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()


     async def completion_stream_generator(
@@ -327,6 +324,7 @@ async def completion_stream_generator(
             del request
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()
             yield "data: [DONE]\n\n"


@@ -353,13 +351,13 @@ def request_output_to_completion_response(
         if request.echo:
             assert prompt_text is not None
             if request.max_tokens == 0:
-                token_ids = prompt_token_ids
+                # token_ids = prompt_token_ids
                 output_text = prompt_text
             else:
-                token_ids = [*prompt_token_ids, *output["token_ids"]]
+                # token_ids = [*prompt_token_ids, *output["token_ids"]]
                 output_text = prompt_text + output["text"]
         else:
-            token_ids = output["token_ids"]
+            # token_ids = output["token_ids"]
             output_text = output["text"]

         choice_data = CompletionResponseChoice(
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index ad13b0831c..7eec9a5cfe 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -121,6 +121,8 @@

     # set trace exporter_otlp_headers.
     "EXPORTER_OTLP_HEADERS":
     lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
+    # max number of supported connections
+    "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
 }
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 7a81f96007..c64f5bf5ce 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -15,6 +15,7 @@
 """

 import argparse
+import asyncio
 import codecs
 import importlib
 import logging
@@ -576,6 +577,60 @@ def version():
     except FileNotFoundError:
         llm_logger.error("[version.txt] Not Found!")

+
+class StatefulSemaphore:
+    """
+    Wraps an asyncio.Semaphore and exposes additional state (acquired count, capacity, uptime).
+    """
+
+    __slots__ = ("_semaphore", "_max_value", "_acquired_count", "_last_reset")
+
+    def __init__(self, value: int):
+        """
+        StatefulSemaphore constructor
+        """
+        if value < 0:
+            raise ValueError("Value must be non-negative.")
+        self._semaphore = asyncio.Semaphore(value)
+        self._max_value = value
+        self._acquired_count = 0
+        self._last_reset = time.monotonic()
+
+    async def acquire(self):
+        await self._semaphore.acquire()
+        self._acquired_count += 1
+
+    def release(self):
+        self._semaphore.release()
+        self._acquired_count = max(0, self._acquired_count - 1)
+
+    def locked(self) -> bool:
+        return self._semaphore.locked()
+
+    @property
+    def available(self) -> int:
+        return self._max_value - self._acquired_count
+
+    @property
+    def acquired(self) -> int:
+        return self._acquired_count
+
+    @property
+    def max_value(self) -> int:
+        return self._max_value
+
+    @property
+    def uptime(self) -> float:
+        return time.monotonic() - self._last_reset
+
+    def status(self) -> dict:
+        return {
+            "available": self.available,
+            "acquired": self.acquired,
+            "max_value": self.max_value,
+            "uptime": round(self.uptime, 2),
+        }
+

 llm_logger = get_logger("fastdeploy", "fastdeploy.log")
 data_processor_logger = get_logger("data_processor", "data_processor.log")
 scheduler_logger = get_logger("scheduler", "scheduler.log")
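
Review note: the gate added in api_server.py is easiest to sanity-check in isolation. The sketch below is my own illustration, not part of the patch; it assumes a build with this diff applied so that fastdeploy.utils.StatefulSemaphore exists. It drives the semaphore the same way connection_manager() does, including the 1 ms asyncio.wait_for gate that maps to HTTP 429:

    import asyncio

    from fastdeploy.utils import StatefulSemaphore  # added by this patch


    async def main():
        sem = StatefulSemaphore(2)   # two slots, standing in for MAX_CONCURRENT_CONNECTIONS

        await sem.acquire()          # request 1 enters
        await sem.acquire()          # request 2 enters
        print(sem.status())          # {'available': 0, 'acquired': 2, 'max_value': 2, 'uptime': ...}

        try:
            # request 3: same 1 ms gate as connection_manager(); times out when full
            await asyncio.wait_for(sem.acquire(), timeout=0.001)
        except asyncio.TimeoutError:
            print("full -> the API server would raise HTTPException(429)")

        sem.release()                # a finished streaming generator frees its slot
        print(sem.status())          # {'available': 1, 'acquired': 1, ...}


    asyncio.run(main())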
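
Both caps are split evenly across uvicorn workers with ceiling division, so the per-worker slots always sum to at least the configured limit. A worked example of the arithmetic (the per_worker helper is hypothetical, named here only for illustration): with the defaults, FD_SUPPORT_MAX_CONNECTIONS = 768 and --max-concurrency 512 under --workers 4 give 192 engine-side slots and 128 connection slots per worker.

    def per_worker(limit: int, workers: int) -> int:
        # ceiling division, as used for both EngineClient.semaphore
        # and MAX_CONCURRENT_CONNECTIONS in this patch
        return (limit + workers - 1) // workers

    assert per_worker(768, 4) == 192   # FD_SUPPORT_MAX_CONNECTIONS across 4 workers
    assert per_worker(512, 4) == 128   # --max-concurrency across 4 workers
    assert per_worker(512, 3) == 171   # rounds up: 3 * 171 = 513 >= 512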