Skip to content

Commit 715e3c1

Browse files
authored
[Proxy Server] changes for option methods and request cancellation (#1972)
1 parent 8a7cb87 commit 715e3c1

File tree

1 file changed

+168
-82
lines changed

1 file changed

+168
-82
lines changed

examples/online_serving/disagg_examples/disagg_proxy_advanced.py

Lines changed: 168 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Request, status)
2020
from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse
2121
from transformers import AutoTokenizer
22+
from asyncio import CancelledError
2223

2324
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s",
2425
"%Y-%m-%d %H:%M:%S")
@@ -30,6 +31,7 @@
3031
logger.addHandler(handler)
3132
logger.propagate = False
3233

34+
from fastapi.middleware.cors import CORSMiddleware
3335

3436
def log_info_blue(msg):
3537
logger.info("%s%s%s", escape_codes['cyan'], msg, escape_codes['reset'])
@@ -47,10 +49,10 @@ def log_info_red(msg):
4749
logger.info("%s%s%s", escape_codes['red'], msg, escape_codes['reset'])
4850

4951

50-
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=60 * 60 * 60,
51-
connect=60000,
52-
sock_read=120000,
53-
sock_connect=30000)
52+
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None,
53+
connect=None,
54+
sock_read=None,
55+
sock_connect=None)
5456

5557

5658
async def P_first_token_generator(generator_p,
@@ -60,40 +62,59 @@ async def P_first_token_generator(generator_p,
6062
decode_instance: str = None,
6163
req_len: int = None):
6264
first_decode = True
63-
async for chunk in generator_p:
64-
yield chunk
65-
if callback_owner and hasattr(callback_owner, "on_done"):
66-
callback_owner.on_done(prefill_instance=prefill_instance,
67-
req_len=req_len)
68-
69-
async for chunk in generator_d:
70-
if first_decode:
71-
first_decode = False
72-
continue
73-
yield chunk
74-
if callback_owner and hasattr(callback_owner, "on_done"):
75-
callback_owner.on_done(decode_instance=decode_instance,
76-
req_len=req_len)
7765

66+
try:
67+
async for chunk in generator_p:
68+
yield chunk
69+
finally:
70+
if callback_owner:
71+
callback_owner.exception_handler(
72+
prefill_instance=prefill_instance,
73+
decode_instance=None,
74+
req_len=req_len
75+
)
76+
77+
try:
78+
async for chunk in generator_d:
79+
if first_decode:
80+
first_decode = False
81+
continue
82+
yield chunk
83+
finally:
84+
if callback_owner:
85+
callback_owner.exception_handler(
86+
prefill_instance=None,
87+
decode_instance=decode_instance,
88+
req_len=req_len
89+
)
7890

7991
async def D_first_token_generator(generator_p,
8092
generator_d,
8193
callback_owner=None,
8294
prefill_instance: str = None,
8395
decode_instance: str = None,
8496
req_len: int = None):
85-
async for _ in generator_p:
86-
continue
87-
if callback_owner and hasattr(callback_owner, "on_done"):
88-
callback_owner.on_done(prefill_instance=prefill_instance,
89-
req_len=req_len)
90-
91-
async for chunk in generator_d:
92-
yield chunk
93-
if callback_owner and hasattr(callback_owner, "on_done"):
94-
callback_owner.on_done(decode_instance=decode_instance,
95-
req_len=req_len)
96-
97+
try:
98+
async for _ in generator_p:
99+
continue
100+
finally:
101+
if callback_owner:
102+
callback_owner.exception_handler(
103+
prefill_instance=prefill_instance,
104+
decode_instance=None,
105+
req_len=req_len
106+
)
107+
108+
try:
109+
async for chunk in generator_d:
110+
yield chunk
111+
finally:
112+
if callback_owner:
113+
callback_owner.exception_handler(
114+
prefill_instance=None,
115+
decode_instance=decode_instance,
116+
req_len=req_len
117+
)
97118

98119
class SchedulingPolicy(ABC):
99120

@@ -152,6 +173,25 @@ def setup_routes(self):
152173
Depends(self.validate_json_request)
153174
])(self.custom_create_chat_completion if self.
154175
custom_create_chat_completion else self.create_chat_completion)
176+
177+
self.router.options("/v1/completions")(lambda: None)
178+
self.router.options("/v1/chat/completions")(lambda: None)
179+
self.router.options("/v1/models")(lambda: None)
180+
self.router.options("/status")(lambda: None)
181+
self.router.options("/health")(lambda: None)
182+
self.router.options("/ping")(lambda: None)
183+
self.router.options("/tokenize")(lambda: None)
184+
self.router.options("/detokenize")(lambda: None)
185+
self.router.options("/version")(lambda: None)
186+
self.router.options("/v1/embeddings")(lambda: None)
187+
self.router.options("/pooling")(lambda: None)
188+
self.router.options("/score")(lambda: None)
189+
self.router.options("/v1/score")(lambda: None)
190+
self.router.options("/rerank")(lambda: None)
191+
self.router.options("/v1/rerank")(lambda: None)
192+
self.router.options("/v2/rerank")(lambda: None)
193+
self.router.options("/invocations")(lambda: None)
194+
155195
self.router.get("/status",
156196
response_class=JSONResponse)(self.get_status)
157197
self.router.post("/instances/add",
@@ -492,10 +532,26 @@ def get_total_token_length(self, prompt):
492532
logger.error("Unsupported prompt type: %s", type(prompt))
493533
return fake_len
494534

535+
def exception_handler(self, prefill_instance=None, decode_instance=None, req_len=None):
536+
if prefill_instance or decode_instance:
537+
try:
538+
self.on_done(
539+
prefill_instance=prefill_instance,
540+
decode_instance=decode_instance,
541+
req_len=req_len
542+
)
543+
except Exception as e:
544+
logger.error(f"Error releasing instances: {e}")
545+
raise
546+
495547
async def create_completion(self, raw_request: Request):
496548
try:
497549
request = await raw_request.json()
498550

551+
total_length = 0
552+
prefill_instance = None
553+
decode_instance = None
554+
499555
if len(self.prefill_instances) > 0:
500556
kv_prepare_request = request.copy()
501557
kv_prepare_request["max_tokens"] = 1
@@ -519,7 +575,7 @@ async def create_completion(self, raw_request: Request):
519575
kv_prepare_request):
520576
value += chunk
521577
except HTTPException as http_exc:
522-
self.remove_instance_endpoint("prefill", prefill_instance)
578+
self.exception_handler(prefill_instance, decode_instance, total_length)
523579
raise http_exc
524580

525581
# Perform kv recv and decoding stage
@@ -540,7 +596,7 @@ async def streaming_response(value):
540596
generator_d = self.forward_request(
541597
f"http://{decode_instance}/v1/completions", request)
542598
except HTTPException as http_exc:
543-
self.remove_instance_endpoint("decode", decode_instance)
599+
self.exception_handler(prefill_instance, decode_instance, total_length)
544600
raise http_exc
545601

546602
if request.get("stream", False):
@@ -559,12 +615,17 @@ async def streaming_response(value):
559615
if request.get("stream", False)
560616
else "application/json"
561617
)
562-
response = StreamingResponse(final_generator,
563-
media_type=media_type)
564-
return response
618+
async def wrapped_generator():
619+
try:
620+
async for chunk in final_generator:
621+
yield chunk
622+
except CancelledError:
623+
logger.warning("[0] Client disconnected during create_completion (CancelledError)")
624+
except Exception as e:
625+
logger.error("[1] Exception in wrapped_generator: %s", str(e))
626+
raise
627+
return StreamingResponse(wrapped_generator(), media_type=media_type)
565628
except Exception:
566-
import sys
567-
568629
exc_info = sys.exc_info()
569630
print("Error occurred in disagg proxy server")
570631
print(exc_info)
@@ -573,6 +634,10 @@ async def create_chat_completion(self, raw_request: Request):
573634
try:
574635
request = await raw_request.json()
575636

637+
total_length = 0
638+
prefill_instance = None
639+
decode_instance = None
640+
576641
# add params to request
577642
kv_prepare_request = request.copy()
578643
kv_prepare_request["max_tokens"] = 1
@@ -599,7 +664,7 @@ async def create_chat_completion(self, raw_request: Request):
599664
kv_prepare_request):
600665
value += chunk
601666
except HTTPException as http_exc:
602-
self.remove_instance_endpoint("prefill", prefill_instance)
667+
self.exception_handler(prefill_instance, decode_instance, total_length)
603668
raise http_exc
604669
# Perform kv recv and decoding stage
605670
decode_instance = self.schedule(self.decode_cycler,
@@ -620,7 +685,7 @@ async def streaming_response(value):
620685
"http://" + decode_instance + "/v1/chat/completions",
621686
request)
622687
except HTTPException as http_exc:
623-
self.remove_instance_endpoint("decode", decode_instance)
688+
self.exception_handler(prefill_instance, decode_instance, total_length)
624689
raise http_exc
625690

626691
if request.get("stream", False):
@@ -639,9 +704,16 @@ async def streaming_response(value):
639704
if request.get("stream", False)
640705
else "application/json"
641706
)
642-
response = StreamingResponse(final_generator,
643-
media_type=media_type)
644-
return response
707+
async def wrapped_generator():
708+
try:
709+
async for chunk in final_generator:
710+
yield chunk
711+
except CancelledError:
712+
logger.warning("[0] Client disconnected during create_completion (CancelledError)")
713+
except Exception as e:
714+
logger.error("[1] Exception in wrapped_generator: %s", str(e))
715+
raise
716+
return StreamingResponse(wrapped_generator(), media_type=media_type)
645717
except Exception:
646718
exc_info = sys.exc_info()
647719
error_messages = [str(e) for e in exc_info if e]
@@ -744,51 +816,57 @@ def schedule_completion(self,
744816
with self.lock:
745817
if prefill_instance:
746818
index = self.prefill_instances.index(prefill_instance)
747-
self.prefill_schedule_completion_index += 1
748-
log_info_yellow(f"<Prefill completed "
749-
f"{self.prefill_schedule_completion_index}> "
750-
f"instance = {index}, req_len={req_len}")
751-
752-
self.prefill_bs_counter[index] -= 1
753-
all_zero = True
754-
for index, _ in enumerate(self.prefill_instances):
755-
if self.prefill_bs_counter[index] != 0:
756-
all_zero = False
757-
break
758-
if all_zero:
759-
log_info_red("<Prefill in idle state>")
760-
for index, _ in enumerate(self.prefill_instances):
761-
self.prefill_utils_counter[index] = 0
819+
if self.prefill_bs_counter[index] == 0:
820+
logger.warning("No alive requests for prefill instance, skipping...")
762821
else:
763-
index = self.prefill_instances.index(prefill_instance)
764-
self.prefill_utils_counter[index] -= req_len
822+
self.prefill_schedule_completion_index += 1
823+
log_info_yellow(f"<Prefill completed "
824+
f"{self.prefill_schedule_completion_index}> "
825+
f"instance = {index}, req_len={req_len}")
826+
827+
self.prefill_bs_counter[index] -= 1
828+
all_zero = True
829+
for index, _ in enumerate(self.prefill_instances):
830+
if self.prefill_bs_counter[index] != 0:
831+
all_zero = False
832+
break
833+
if all_zero:
834+
log_info_red("<Prefill in idle state>")
835+
for index, _ in enumerate(self.prefill_instances):
836+
self.prefill_utils_counter[index] = 0
837+
else:
838+
index = self.prefill_instances.index(prefill_instance)
839+
self.prefill_utils_counter[index] -= req_len
765840

766841
if decode_instance:
767842
index = self.decode_instances.index(decode_instance)
768-
self.decode_schedule_completion_index += 1
769-
log_info_blue(f"<Decode completed "
770-
f"{self.decode_schedule_completion_index}> "
771-
f"instance = {index}, req_len={req_len}")
772-
773-
self.decode_bs_counter[index] -= 1
774-
all_zero = True
775-
for index, _ in enumerate(self.decode_instances):
776-
if self.decode_bs_counter[index] != 0:
777-
all_zero = False
778-
break
779-
if all_zero:
780-
log_info_red("<Decode in idle state>")
781-
self.decode_kv_utils_counter = [0] * len(
782-
self.decode_instances)
843+
if self.decode_bs_counter[index] == 0:
844+
logger.warning("No alive requests for decode instance, skipping...")
783845
else:
784-
index = self.decode_instances.index(decode_instance)
785-
self.decode_kv_utils_counter[index] -= req_len
786-
log_info_blue(
787-
f"<schedule_completion decode> "
788-
f"decode_bs_counter: {self.decode_bs_counter}")
789-
log_info_blue(f"<schedule_completion decode> "
790-
f"decode_kv_utils_counter: "
791-
f"{self.decode_kv_utils_counter}")
846+
self.decode_schedule_completion_index += 1
847+
log_info_blue(f"<Decode completed "
848+
f"{self.decode_schedule_completion_index}> "
849+
f"instance = {index}, req_len={req_len}")
850+
851+
self.decode_bs_counter[index] -= 1
852+
all_zero = True
853+
for index, _ in enumerate(self.decode_instances):
854+
if self.decode_bs_counter[index] != 0:
855+
all_zero = False
856+
break
857+
if all_zero:
858+
log_info_red("<Decode in idle state>")
859+
self.decode_kv_utils_counter = [0] * len(
860+
self.decode_instances)
861+
else:
862+
index = self.decode_instances.index(decode_instance)
863+
self.decode_kv_utils_counter[index] -= req_len
864+
log_info_blue(
865+
f"<schedule_completion decode> "
866+
f"decode_bs_counter: {self.decode_bs_counter}")
867+
log_info_blue(f"<schedule_completion decode> "
868+
f"decode_kv_utils_counter: "
869+
f"{self.decode_kv_utils_counter}")
792870

793871

794872
class ProxyServer:
@@ -861,6 +939,14 @@ def verify_model_config(self, instances: list, model: str) -> None:
861939

862940
def run_server(self):
863941
app = FastAPI()
942+
app.add_middleware(
943+
CORSMiddleware,
944+
allow_origins=["*"],
945+
allow_credentials=False,
946+
allow_methods=["*"],
947+
allow_headers=["*"],
948+
)
949+
864950
app.include_router(self.proxy_instance.router)
865951
config = uvicorn.Config(app,
866952
host="0.0.0.0",

0 commit comments

Comments
 (0)