1919 Request , status )
2020from fastapi .responses import JSONResponse , StreamingResponse , PlainTextResponse
2121from transformers import AutoTokenizer
22+ from asyncio import CancelledError
2223
2324formatter = logging .Formatter ("[%(asctime)s] %(levelname)s - %(message)s" ,
2425 "%Y-%m-%d %H:%M:%S" )
3031logger .addHandler (handler )
3132logger .propagate = False
3233
34+ from fastapi .middleware .cors import CORSMiddleware
3335
3436def log_info_blue (msg ):
3537 logger .info ("%s%s%s" , escape_codes ['cyan' ], msg , escape_codes ['reset' ])
@@ -47,10 +49,10 @@ def log_info_red(msg):
4749 logger .info ("%s%s%s" , escape_codes ['red' ], msg , escape_codes ['reset' ])
4850
4951
# Client-side timeouts are fully disabled: proxied prefill/decode streams
# can legitimately run for a very long time, so every ClientTimeout field
# is set to None (no limit).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None,
                                        connect=None,
                                        sock_read=None,
                                        sock_connect=None)
5456
5557
async def P_first_token_generator(generator_p,
                                  generator_d,
                                  callback_owner=None,
                                  prefill_instance: str = None,
                                  decode_instance: str = None,
                                  req_len: int = None):
    """Stream the prefill output first, then the decode output.

    Yields every chunk from ``generator_p``, then every chunk from
    ``generator_d`` except its first one (the first decode chunk is
    intentionally skipped -- presumably it duplicates the prefill
    output; TODO confirm against the decode endpoint's behavior).

    After each stage finishes -- normally, by exception, or because the
    client disconnected -- ``callback_owner.exception_handler`` is
    invoked from a ``finally`` block so the scheduler's per-instance
    bookkeeping is always released.

    Fix over the previous revision: guard the callback with ``hasattr``
    (the pre-rewrite code guarded ``on_done`` the same way), so an
    owner object without ``exception_handler`` does not raise
    ``AttributeError`` from inside ``finally`` and mask the original
    exception.
    """
    first_decode = True

    # Prefill stage: relay everything.
    try:
        async for chunk in generator_p:
            yield chunk
    finally:
        if callback_owner and hasattr(callback_owner, "exception_handler"):
            callback_owner.exception_handler(
                prefill_instance=prefill_instance,
                decode_instance=None,
                req_len=req_len
            )

    # Decode stage: relay everything except the first chunk.
    try:
        async for chunk in generator_d:
            if first_decode:
                first_decode = False
                continue
            yield chunk
    finally:
        if callback_owner and hasattr(callback_owner, "exception_handler"):
            callback_owner.exception_handler(
                prefill_instance=None,
                decode_instance=decode_instance,
                req_len=req_len
            )
7890
async def D_first_token_generator(generator_p,
                                  generator_d,
                                  callback_owner=None,
                                  prefill_instance: str = None,
                                  decode_instance: str = None,
                                  req_len: int = None):
    """Drain the prefill output silently, then stream the decode output.

    The prefill stage is consumed without forwarding anything (it only
    exists to warm the KV cache on the prefill instance -- TODO confirm);
    every decode chunk is then relayed to the caller.

    After each stage finishes -- normally, by exception, or because the
    client disconnected -- ``callback_owner.exception_handler`` is
    invoked from a ``finally`` block so the scheduler's per-instance
    bookkeeping is always released.

    Fix over the previous revision: guard the callback with ``hasattr``
    (the pre-rewrite code guarded ``on_done`` the same way), so an
    owner object without ``exception_handler`` does not raise
    ``AttributeError`` from inside ``finally`` and mask the original
    exception.
    """
    # Prefill stage: consume without yielding.
    try:
        async for _ in generator_p:
            continue
    finally:
        if callback_owner and hasattr(callback_owner, "exception_handler"):
            callback_owner.exception_handler(
                prefill_instance=prefill_instance,
                decode_instance=None,
                req_len=req_len
            )

    # Decode stage: relay everything.
    try:
        async for chunk in generator_d:
            yield chunk
    finally:
        if callback_owner and hasattr(callback_owner, "exception_handler"):
            callback_owner.exception_handler(
                prefill_instance=None,
                decode_instance=decode_instance,
                req_len=req_len
            )
97118
98119class SchedulingPolicy (ABC ):
99120
@@ -152,6 +173,25 @@ def setup_routes(self):
152173 Depends (self .validate_json_request )
153174 ])(self .custom_create_chat_completion if self .
154175 custom_create_chat_completion else self .create_chat_completion )
176+
177+ self .router .options ("/v1/completions" )(lambda : None )
178+ self .router .options ("/v1/chat/completions" )(lambda : None )
179+ self .router .options ("/v1/models" )(lambda : None )
180+ self .router .options ("/status" )(lambda : None )
181+ self .router .options ("/health" )(lambda : None )
182+ self .router .options ("/ping" )(lambda : None )
183+ self .router .options ("/tokenize" )(lambda : None )
184+ self .router .options ("/detokenize" )(lambda : None )
185+ self .router .options ("/version" )(lambda : None )
186+ self .router .options ("/v1/embeddings" )(lambda : None )
187+ self .router .options ("/pooling" )(lambda : None )
188+ self .router .options ("/score" )(lambda : None )
189+ self .router .options ("/v1/score" )(lambda : None )
190+ self .router .options ("/rerank" )(lambda : None )
191+ self .router .options ("/v1/rerank" )(lambda : None )
192+ self .router .options ("/v2/rerank" )(lambda : None )
193+ self .router .options ("/invocations" )(lambda : None )
194+
155195 self .router .get ("/status" ,
156196 response_class = JSONResponse )(self .get_status )
157197 self .router .post ("/instances/add" ,
@@ -492,10 +532,26 @@ def get_total_token_length(self, prompt):
492532 logger .error ("Unsupported prompt type: %s" , type (prompt ))
493533 return fake_len
494534
def exception_handler(self, prefill_instance=None, decode_instance=None, req_len=None):
    """Release scheduler bookkeeping for the given instance(s).

    Delegates to ``self.on_done`` (which decrements the per-instance
    batch/utilization counters) for whichever of *prefill_instance* /
    *decode_instance* is set; a call with neither set is a no-op.

    If the release itself fails, the error is logged and re-raised.
    NOTE(review): this is called from ``finally`` blocks in the
    streaming generators, so re-raising here can replace an in-flight
    exception -- confirm that is acceptable.
    """
    if prefill_instance or decode_instance:
        try:
            self.on_done(
                prefill_instance=prefill_instance,
                decode_instance=decode_instance,
                req_len=req_len
            )
        except Exception as e:
            # Lazy %-style args: the message is only formatted if the
            # record is actually emitted (f-string was formatted eagerly).
            logger.error("Error releasing instances: %s", e)
            raise
495547 async def create_completion (self , raw_request : Request ):
496548 try :
497549 request = await raw_request .json ()
498550
551+ total_length = 0
552+ prefill_instance = None
553+ decode_instance = None
554+
499555 if len (self .prefill_instances ) > 0 :
500556 kv_prepare_request = request .copy ()
501557 kv_prepare_request ["max_tokens" ] = 1
@@ -519,7 +575,7 @@ async def create_completion(self, raw_request: Request):
519575 kv_prepare_request ):
520576 value += chunk
521577 except HTTPException as http_exc :
522- self .remove_instance_endpoint ( "prefill" , prefill_instance )
578+ self .exception_handler ( prefill_instance , decode_instance , total_length )
523579 raise http_exc
524580
525581 # Perform kv recv and decoding stage
@@ -540,7 +596,7 @@ async def streaming_response(value):
540596 generator_d = self .forward_request (
541597 f"http://{ decode_instance } /v1/completions" , request )
542598 except HTTPException as http_exc :
543- self .remove_instance_endpoint ( "decode" , decode_instance )
599+ self .exception_handler ( prefill_instance , decode_instance , total_length )
544600 raise http_exc
545601
546602 if request .get ("stream" , False ):
@@ -559,12 +615,17 @@ async def streaming_response(value):
559615 if request .get ("stream" , False )
560616 else "application/json"
561617 )
562- response = StreamingResponse (final_generator ,
563- media_type = media_type )
564- return response
async def wrapped_generator():
    """Relay chunks from ``final_generator``, downgrading a mid-stream
    client disconnect to a warning instead of a crashed response."""
    try:
        async for chunk in final_generator:
            yield chunk
    except CancelledError:
        # Client went away mid-stream; end the response quietly.
        # NOTE(review): swallowing CancelledError also suppresses task
        # cancellation during server shutdown -- confirm this is intended.
        logger.warning("[0] Client disconnected during create_completion (CancelledError)")
    except Exception as e:
        # logger.exception records the traceback as well as the message
        # (plain logger.error dropped the traceback).
        logger.exception("[1] Exception in wrapped_generator: %s", e)
        raise
627+ return StreamingResponse (wrapped_generator (), media_type = media_type )
565628 except Exception :
566- import sys
567-
568629 exc_info = sys .exc_info ()
569630 print ("Error occurred in disagg proxy server" )
570631 print (exc_info )
@@ -573,6 +634,10 @@ async def create_chat_completion(self, raw_request: Request):
573634 try :
574635 request = await raw_request .json ()
575636
637+ total_length = 0
638+ prefill_instance = None
639+ decode_instance = None
640+
576641 # add params to request
577642 kv_prepare_request = request .copy ()
578643 kv_prepare_request ["max_tokens" ] = 1
@@ -599,7 +664,7 @@ async def create_chat_completion(self, raw_request: Request):
599664 kv_prepare_request ):
600665 value += chunk
601666 except HTTPException as http_exc :
602- self .remove_instance_endpoint ( "prefill" , prefill_instance )
667+ self .exception_handler ( prefill_instance , decode_instance , total_length )
603668 raise http_exc
604669 # Perform kv recv and decoding stage
605670 decode_instance = self .schedule (self .decode_cycler ,
@@ -620,7 +685,7 @@ async def streaming_response(value):
620685 "http://" + decode_instance + "/v1/chat/completions" ,
621686 request )
622687 except HTTPException as http_exc :
623- self .remove_instance_endpoint ( "decode" , decode_instance )
688+ self .exception_handler ( prefill_instance , decode_instance , total_length )
624689 raise http_exc
625690
626691 if request .get ("stream" , False ):
@@ -639,9 +704,16 @@ async def streaming_response(value):
639704 if request .get ("stream" , False )
640705 else "application/json"
641706 )
642- response = StreamingResponse (final_generator ,
643- media_type = media_type )
644- return response
async def wrapped_generator():
    """Relay chunks from ``final_generator``, downgrading a mid-stream
    client disconnect to a warning instead of a crashed response."""
    try:
        async for chunk in final_generator:
            yield chunk
    except CancelledError:
        # Fix: this copy lives in create_chat_completion, but the log
        # message said "create_completion" (copy-paste from the sibling).
        # NOTE(review): swallowing CancelledError also suppresses task
        # cancellation during server shutdown -- confirm this is intended.
        logger.warning("[0] Client disconnected during create_chat_completion (CancelledError)")
    except Exception as e:
        # logger.exception records the traceback as well as the message
        # (plain logger.error dropped the traceback).
        logger.exception("[1] Exception in wrapped_generator: %s", e)
        raise
716+ return StreamingResponse (wrapped_generator (), media_type = media_type )
645717 except Exception :
646718 exc_info = sys .exc_info ()
647719 error_messages = [str (e ) for e in exc_info if e ]
@@ -744,51 +816,57 @@ def schedule_completion(self,
744816 with self .lock :
745817 if prefill_instance :
746818 index = self .prefill_instances .index (prefill_instance )
747- self .prefill_schedule_completion_index += 1
748- log_info_yellow (f"<Prefill completed "
749- f"{ self .prefill_schedule_completion_index } > "
750- f"instance = { index } , req_len={ req_len } " )
751-
752- self .prefill_bs_counter [index ] -= 1
753- all_zero = True
754- for index , _ in enumerate (self .prefill_instances ):
755- if self .prefill_bs_counter [index ] != 0 :
756- all_zero = False
757- break
758- if all_zero :
759- log_info_red ("<Prefill in idle state>" )
760- for index , _ in enumerate (self .prefill_instances ):
761- self .prefill_utils_counter [index ] = 0
819+ if self .prefill_bs_counter [index ] == 0 :
820+ logger .warning ("No alive requests for prefill instance, skipping..." )
762821 else :
763- index = self .prefill_instances .index (prefill_instance )
764- self .prefill_utils_counter [index ] -= req_len
822+ self .prefill_schedule_completion_index += 1
823+ log_info_yellow (f"<Prefill completed "
824+ f"{ self .prefill_schedule_completion_index } > "
825+ f"instance = { index } , req_len={ req_len } " )
826+
827+ self .prefill_bs_counter [index ] -= 1
828+ all_zero = True
829+ for index , _ in enumerate (self .prefill_instances ):
830+ if self .prefill_bs_counter [index ] != 0 :
831+ all_zero = False
832+ break
833+ if all_zero :
834+ log_info_red ("<Prefill in idle state>" )
835+ for index , _ in enumerate (self .prefill_instances ):
836+ self .prefill_utils_counter [index ] = 0
837+ else :
838+ index = self .prefill_instances .index (prefill_instance )
839+ self .prefill_utils_counter [index ] -= req_len
765840
766841 if decode_instance :
767842 index = self .decode_instances .index (decode_instance )
768- self .decode_schedule_completion_index += 1
769- log_info_blue (f"<Decode completed "
770- f"{ self .decode_schedule_completion_index } > "
771- f"instance = { index } , req_len={ req_len } " )
772-
773- self .decode_bs_counter [index ] -= 1
774- all_zero = True
775- for index , _ in enumerate (self .decode_instances ):
776- if self .decode_bs_counter [index ] != 0 :
777- all_zero = False
778- break
779- if all_zero :
780- log_info_red ("<Decode in idle state>" )
781- self .decode_kv_utils_counter = [0 ] * len (
782- self .decode_instances )
843+ if self .decode_bs_counter [index ] == 0 :
844+ logger .warning ("No alive requests for decode instance, skipping..." )
783845 else :
784- index = self .decode_instances .index (decode_instance )
785- self .decode_kv_utils_counter [index ] -= req_len
786- log_info_blue (
787- f"<schedule_completion decode> "
788- f"decode_bs_counter: { self .decode_bs_counter } " )
789- log_info_blue (f"<schedule_completion decode> "
790- f"decode_kv_utils_counter: "
791- f"{ self .decode_kv_utils_counter } " )
846+ self .decode_schedule_completion_index += 1
847+ log_info_blue (f"<Decode completed "
848+ f"{ self .decode_schedule_completion_index } > "
849+ f"instance = { index } , req_len={ req_len } " )
850+
851+ self .decode_bs_counter [index ] -= 1
852+ all_zero = True
853+ for index , _ in enumerate (self .decode_instances ):
854+ if self .decode_bs_counter [index ] != 0 :
855+ all_zero = False
856+ break
857+ if all_zero :
858+ log_info_red ("<Decode in idle state>" )
859+ self .decode_kv_utils_counter = [0 ] * len (
860+ self .decode_instances )
861+ else :
862+ index = self .decode_instances .index (decode_instance )
863+ self .decode_kv_utils_counter [index ] -= req_len
864+ log_info_blue (
865+ f"<schedule_completion decode> "
866+ f"decode_bs_counter: { self .decode_bs_counter } " )
867+ log_info_blue (f"<schedule_completion decode> "
868+ f"decode_kv_utils_counter: "
869+ f"{ self .decode_kv_utils_counter } " )
792870
793871
794872class ProxyServer :
@@ -861,6 +939,14 @@ def verify_model_config(self, instances: list, model: str) -> None:
861939
862940 def run_server (self ):
863941 app = FastAPI ()
942+ app .add_middleware (
943+ CORSMiddleware ,
944+ allow_origins = ["*" ],
945+ allow_credentials = False ,
946+ allow_methods = ["*" ],
947+ allow_headers = ["*" ],
948+ )
949+
864950 app .include_router (self .proxy_instance .router )
865951 config = uvicorn .Config (app ,
866952 host = "0.0.0.0" ,
0 commit comments