
Commit 86223cd

pd master health, tokens and server busy error (#895)

1 parent 28de112 commit 86223cd

File tree

5 files changed: +49 additions, −42 deletions

lightllm/server/api_http.py

Lines changed: 11 additions & 10 deletions

@@ -41,9 +41,10 @@
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
-from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
+from .api_lightllm import lightllm_get_score
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.error_utils import ServerBusyError
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.envs_utils import get_unique_server_name
 from dataclasses import dataclass

@@ -136,6 +137,9 @@ def get_model_name():
 @app.get("/health", summary="Check server health")
 @app.head("/health", summary="Check server health")
 async def healthcheck(request: Request):
+    if g_objs.args.run_mode == "pd_master":
+        return JSONResponse({"message": "Ok"}, status_code=200)
+
     if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
         return JSONResponse({"message": "Error"}, status_code=503)
     from lightllm.utils.health_check import health_check, health_obj

@@ -175,6 +179,9 @@ async def token_load(request: Request):
 async def generate(request: Request) -> Response:
     try:
         return await g_objs.g_generate_func(request, g_objs.httpserver_manager)
+    except ServerBusyError as e:
+        logger.error("%s", str(e), exc_info=True)
+        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
     except Exception as e:
         logger.error("An error occurred: %s", str(e), exc_info=True)
         return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))

@@ -184,15 +191,9 @@ async def generate(request: Request) -> Response:
 async def generate_stream(request: Request) -> Response:
     try:
         return await g_objs.g_generate_stream_func(request, g_objs.httpserver_manager)
-    except Exception as e:
-        logger.error("An error occurred: %s", str(e), exc_info=True)
-        return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
-
-
-@app.post("/pd_generate_stream")
-async def pd_generate_stream(request: Request) -> Response:
-    try:
-        return await lightllm_pd_generate_stream(request, g_objs.httpserver_manager)
+    except ServerBusyError as e:
+        logger.error("%s", str(e), exc_info=True)
+        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
     except Exception as e:
         logger.error("An error occurred: %s", str(e), exc_info=True)
         return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
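With this change, /generate and /generate_stream answer 503 Service Unavailable when a ServerBusyError is raised, instead of folding it into the generic 417 path, and /health on a pd_master node replies Ok immediately. A minimal client-side sketch (not part of this commit) of how a caller might back off on the new 503, assuming a server at http://localhost:8000 and the usual lightllm inputs/parameters payload:

import time

import requests


def generate_with_retry(prompt, base_url="http://localhost:8000", retries=3):
    # Retry /generate with exponential backoff while the PD master reports 503 (server busy).
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 16}}
    for attempt in range(retries):
        resp = requests.post(f"{base_url}/generate", json=payload, timeout=60)
        if resp.status_code == 503:
            # New ServerBusyError path: back off and try again.
            time.sleep(2 ** attempt)
            continue
        resp.raise_for_status()
        return resp.json()
    raise RuntimeError("server stayed busy after all retries")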

lightllm/server/api_lightllm.py

Lines changed: 0 additions & 26 deletions

@@ -136,29 +136,3 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
     background_tasks = BackgroundTasks()
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
-
-
-async def lightllm_pd_generate_stream(request: Request, httpserver_manager: HttpServerManager) -> Response:
-
-    request_dict = await request.json()
-    prompt = request_dict.pop("inputs")
-    sample_params_dict = request_dict["parameters"]
-    _ = sample_params_dict.pop("return_details", False)
-    sampling_params = SamplingParams()
-    sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
-    sampling_params.verify()
-    if sampling_params.best_of != 1:
-        raise Exception("stream api only support best_of == 1")
-
-    multimodal_params_dict = request_dict.get("multimodal_params", {})
-    multimodal_params = MultimodalParams(**multimodal_params_dict)
-    results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request)
-
-    # Streaming case
-    async def stream_results() -> AsyncGenerator[bytes, None]:
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            ret = [sub_req_id, request_output, metadata, finish_status.value]
-            yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8")
-
-    background_tasks = BackgroundTasks()
-    return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)

lightllm/server/api_openai.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
-from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
+from .api_lightllm import lightllm_get_score
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
 
 from lightllm.utils.log_utils import init_logger

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 19 additions & 5 deletions

@@ -24,6 +24,7 @@
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.server.httpserver.manager import AsyncQueue
+from lightllm.utils.error_utils import ServerBusyError
 
 logger = init_logger(__name__)
 

@@ -87,9 +88,22 @@ async def update_req_status(self, upkv_status: UpKVStatus):
             pass
         return
 
-    def tokens(self, prompt: str):
-        # to do
-        raise NotImplementedError("tokens is not implements")
+    def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+        prompt_ids = self.tokenizer.encode(prompt, None, **kwargs)
+        image_tokens = 0
+        img_count = 0
+        audio_tokens = 0
+        audio_count = 0
+        for img in multimodal_params.images:
+            img_count += 1
+            self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
+            image_tokens += self.tokenizer.get_image_token_length(img)
+        for audio in multimodal_params.audios:
+            audio_count += 1
+            self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)
+            audio_tokens += self.tokenizer.get_audio_token_length(audio)
+        return len(prompt_ids) + image_tokens + img_count + audio_tokens + audio_count
 
     async def select_p_d_node(
         self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams

@@ -219,8 +233,8 @@ async def fetch_stream(
            try:
                await asyncio.wait_for(up_status_event.wait(), timeout=60)
            except asyncio.TimeoutError:
-               logger.warning(f"group_request_id: {group_request_id} kv move time out err")
-               assert False, f"req_id {group_request_id} kv move time out, server is busy"
+               logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
+               raise ServerBusyError()
 
            sampling_params.move_kv_to_decode_node.initialize(None)
            sampling_params.max_new_tokens = old_max_new_tokens - 1
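The rewritten tokens() now estimates a request's total token count as the encoded text length plus each image's and audio clip's token length, with one extra token counted per image and per audio item (the img_count / audio_count terms). A back-of-the-envelope sketch with made-up numbers, just to show how the return value is composed:

# Hypothetical counts for illustration only.
text_tokens = 12               # len(prompt_ids) from tokenizer.encode(prompt, None, **kwargs)
image_token_lens = [576, 576]  # tokenizer.get_image_token_length(img) for each image
audio_token_lens = [128]       # tokenizer.get_audio_token_length(audio) for each audio clip

total = (
    text_tokens
    + sum(image_token_lens) + len(image_token_lens)  # image tokens + one per image
    + sum(audio_token_lens) + len(audio_token_lens)  # audio tokens + one per audio
)
print(total)  # 12 + 1152 + 2 + 128 + 1 = 1295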

lightllm/utils/error_utils.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+class ServerBusyError(Exception):
+    """Custom exception for server busy/overload situations"""
+
+    def __init__(self, message="Server is busy, please try again later", status_code=503):
+        """
+        Initialize the ServerBusyError
+
+        Args:
+            message (str): Error message to display
+            status_code (int): HTTP status code (default 503 Service Unavailable)
+        """
+        super().__init__(message)
+        self.message = message
+        self.status_code = status_code  # HTTP 503 Service Unavailable
+
+    def __str__(self):
+        """String representation of the error"""
+        return f"{self.message} (Status code: {self.status_code})"
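A short usage sketch (not from this commit) showing how the new exception is meant to be consumed: raise it when the server is overloaded and map it to an HTTP 503 response, mirroring the handlers added in api_http.py. The queue_is_full flag is a made-up stand-in for whatever overload condition triggers the error:

from lightllm.utils.error_utils import ServerBusyError


def handle_request(queue_is_full):
    # Returns (status_code, body) the way the HTTP layer would.
    try:
        if queue_is_full:
            raise ServerBusyError()
        return 200, "ok"
    except ServerBusyError as e:
        # e.status_code is 503; str(e) appends it to the message.
        return e.status_code, str(e)


print(handle_request(True))
# (503, 'Server is busy, please try again later (Status code: 503)')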
