Skip to content

Commit 152e2df

Browse files
arekay and Shixiaowei02 authored
[Disaggregated] Add retry knobs and handling (NVIDIA#5808)
Signed-off-by: Rashid Kaleem <[email protected]>
Signed-off-by: Shi Xiaowei <[email protected]>
Co-authored-by: Shi Xiaowei <[email protected]>
1 parent fc8b29c commit 152e2df

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def disaggregated(config_file: Optional[str],
362362
gen_servers=gen_server_urls,
363363
req_timeout_secs=request_timeout,
364364
server_start_timeout_secs=server_start_timeout,
365+
max_retries=disagg_cfg.max_retries,
365366
ctx_router_config=disagg_cfg.ctx_router_config,
366367
gen_router_config=disagg_cfg.gen_router_config,
367368
conditional_disagg_config=disagg_cfg.conditional_disagg_config,

tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class DisaggServerConfig():
5050
ctx_router_config: Optional[RouterConfig] = None
5151
gen_router_config: Optional[RouterConfig] = None
5252
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
53+
max_retries: int = 3
5354

5455

5556
@dataclass
@@ -74,6 +75,7 @@ def parse_disagg_config_file(yaml_config_file: str):
7475

7576
def extract_disagg_cfg(hostname: str = 'localhost',
7677
port: int = 8000,
78+
max_retries: int = 3,
7779
context_servers: Optional[dict] = None,
7880
generation_servers: Optional[dict] = None,
7981
conditional_disagg_config: Optional[dict] = None,
@@ -112,7 +114,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
112114

113115
config = DisaggServerConfig(server_configs, hostname, port,
114116
ctx_router_config, gen_router_config,
115-
conditional_disagg_config)
117+
conditional_disagg_config, max_retries)
116118

117119
return config
118120

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from fastapi import FastAPI, HTTPException
1414
from fastapi.exceptions import RequestValidationError
1515
from fastapi.responses import JSONResponse, Response, StreamingResponse
16+
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
1617

1718
# yapf: disable
1819
from tensorrt_llm.executor import CppExecutorError
@@ -40,6 +41,7 @@ def __init__(self,
4041
gen_servers: List[str],
4142
req_timeout_secs: int = 180,
4243
server_start_timeout_secs: int = 180,
44+
max_retries: int = 3,
4345
ctx_router_config: Optional[RouterConfig] = None,
4446
gen_router_config: Optional[RouterConfig] = None,
4547
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None,
@@ -52,6 +54,10 @@ def __init__(self,
5254
self.gen_router = create_router(gen_router_config, gen_servers, metadata_server_cfg, self.metadata_server)
5355
self.conditional_disagg_config = conditional_disagg_config
5456

57+
if max_retries < 0:
58+
raise ValueError(f"Max retries {max_retries} must be greater than or equal to 0")
59+
self.max_retries = max_retries
60+
logger.info(f"Server max retries: {self.max_retries}")
5561

5662
if (len(self.gen_servers) == 0):
5763
raise ValueError("At least one generation server must be provided")
@@ -323,20 +329,32 @@ async def send_request(self, url: str,
323329
endpoint: str,
324330
response_type: Type[Union[CompletionResponse, ChatCompletionResponse]],
325331
create_generator: callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
326-
if request.stream:
327-
response_generator = create_generator(url, request)
328-
return StreamingResponse(content=response_generator, media_type="text/event-stream")
329-
else:
330-
async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True)) as response:
331-
content_type = response.headers.get("Content-Type", "")
332-
if "text/event-stream" in content_type:
333-
raise ValueError("Received an event-stream although request stream was False")
332+
for attempt in range(self.max_retries + 1):
333+
try:
334+
if request.stream:
335+
response_generator = create_generator(url, request)
336+
return StreamingResponse(content=response_generator, media_type="text/event-stream")
337+
else:
338+
async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True)) as response:
339+
content_type = response.headers.get("Content-Type", "")
340+
if "text/event-stream" in content_type:
341+
raise ValueError("Received an event-stream although request stream was False")
342+
343+
response_dict = await response.json()
344+
if not response.ok:
345+
logger.error(f"Received failed response {response_dict}")
346+
response.raise_for_status()
347+
return response_type(**response_dict)
348+
except (aiohttp.ClientError, OSError) as e:
349+
if attempt == self.max_retries:
350+
raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail=f"Too many requests") from e
351+
logger.error(f"Client error: {e} - retry {attempt} of {self.max_retries}")
352+
# TODO : add a configurable retry interval
353+
await asyncio.sleep(1)
354+
except Exception as e:
355+
logger.error(f"Error encountered while processing request to {url+endpoint}: {e}")
356+
raise
334357

335-
response_dict = await response.json()
336-
if not response.ok:
337-
logger.error(f"Received failed response {response_dict}")
338-
response.raise_for_status()
339-
return response_type(**response_dict)
340358

341359
async def send_completion_request(self, url: str, request: CompletionRequest) -> Union[CompletionResponse, StreamingResponse]:
342360
return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator)

0 commit comments

Comments (0)