1313from fastapi import FastAPI , HTTPException
1414from fastapi .exceptions import RequestValidationError
1515from fastapi .responses import JSONResponse , Response , StreamingResponse
16+ from starlette .status import HTTP_429_TOO_MANY_REQUESTS
1617
1718# yapf: disable
1819from tensorrt_llm .executor import CppExecutorError
@@ -40,6 +41,7 @@ def __init__(self,
4041 gen_servers : List [str ],
4142 req_timeout_secs : int = 180 ,
4243 server_start_timeout_secs : int = 180 ,
44+ max_retries : int = 3 ,
4345 ctx_router_config : Optional [RouterConfig ] = None ,
4446 gen_router_config : Optional [RouterConfig ] = None ,
4547 conditional_disagg_config : Optional [ConditionalDisaggConfig ] = None ,
@@ -52,6 +54,10 @@ def __init__(self,
5254 self .gen_router = create_router (gen_router_config , gen_servers , metadata_server_cfg , self .metadata_server )
5355 self .conditional_disagg_config = conditional_disagg_config
5456
57+ if max_retries < 0 :
58+ raise ValueError (f"Max retries { max_retries } must be greater than or equal to 0" )
59+ self .max_retries = max_retries
60+ logger .info (f"Server max retries: { self .max_retries } " )
5561
5662 if (len (self .gen_servers ) == 0 ):
5763 raise ValueError ("At least one generation server must be provided" )
async def send_request(self, url: str,
                       request: Union[CompletionRequest, ChatCompletionRequest],
                       endpoint: str,
                       response_type: Type[Union[CompletionResponse, ChatCompletionResponse]],
                       create_generator: callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
    """Send ``request`` to ``url + endpoint``, retrying on transport errors.

    A streaming request is answered with a ``StreamingResponse`` wrapping
    the object returned by ``create_generator(url, request)``. A
    non-streaming request is POSTed through ``self.session`` and the JSON
    body is unpacked into ``response_type``.

    Args:
        url: Base URL of the target server.
        request: Request model; the body reads ``request.stream`` and calls
            ``request.model_dump(exclude_unset=True)``.
        endpoint: Path appended to ``url`` for the non-streaming POST.
        response_type: Class instantiated as ``response_type(**json_body)``.
        create_generator: Callable invoked as ``create_generator(url, request)``
            to build the streaming content.

    Returns:
        A ``StreamingResponse`` when ``request.stream`` is truthy, otherwise
        an instance of ``response_type``.

    Raises:
        HTTPException: With status 429 once the initial attempt and all
            ``self.max_retries`` retries have failed with
            ``aiohttp.ClientError`` or ``OSError``.
        ValueError: If the server answers a non-streaming request with a
            ``text/event-stream`` body (logged and re-raised, not retried).
    """
    for attempt in range(self.max_retries + 1):
        try:
            if request.stream:
                # NOTE(review): this returns before any I/O happens here; if
                # create_generator is lazy, connection errors surface while
                # streaming and are not covered by this retry loop - confirm.
                response_generator = create_generator(url, request)
                return StreamingResponse(content=response_generator, media_type="text/event-stream")

            async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True)) as response:
                content_type = response.headers.get("Content-Type", "")
                if "text/event-stream" in content_type:
                    raise ValueError("Received an event-stream although request stream was False")

                if not response.ok:
                    # Read the raw text rather than JSON: an error body is not
                    # guaranteed to be JSON, and a decode failure here would
                    # hide the actual payload from the log.
                    error_body = await response.text()
                    logger.error(f"Received failed response {error_body}")
                    response.raise_for_status()

                response_dict = await response.json()
                return response_type(**response_dict)
        except (aiohttp.ClientError, OSError) as e:
            if attempt == self.max_retries:
                # NOTE(review): 429 is kept for caller compatibility even
                # though the underlying failure may be a connection error.
                raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") from e
            # attempt is zero-based; report the upcoming retry as 1..max_retries.
            logger.error(f"Client error: {e} - retry {attempt + 1} of {self.max_retries}")
            # TODO: add a configurable retry interval
            await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Error encountered while processing request to {url + endpoint}: {e}")
            raise
340358
async def send_completion_request(self, url: str, request: CompletionRequest) -> Union[CompletionResponse, StreamingResponse]:
    """Forward a completion request to the ``/v1/completions`` endpoint.

    Delegates to ``self.send_request`` with ``CompletionResponse`` as the
    response class and ``self.create_completion_generator`` as the
    streaming generator factory, and returns whatever that call yields.
    """
    completion_result = await self.send_request(
        url,
        request,
        "/v1/completions",
        CompletionResponse,
        self.create_completion_generator,
    )
    return completion_result
0 commit comments