
Commit 104ab43

Reject requests to protect server (#1275)
This PR monitors the request queue within the shortfin server and rejects requests with a 503 error when the queue is filled beyond its limit of max batch size + 2. TODO: Add load testing to the integration tests.
1 parent 7b56254 commit 104ab43
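
The load testing mentioned in the TODO is not part of this commit. A hedged sketch of what such an integration test could look like follows; it assumes a server already running at BASE_URL with the usual /generate endpoint, and the payload fields, worker counts, and request counts are placeholders rather than the exact request schema. The only behavior it checks is the one this commit adds: overload yields clean 503s instead of failures.

# Hypothetical load test; BASE_URL, the payload shape, and the worker/request
# counts are assumptions, not part of this commit.
import concurrent.futures

import requests

BASE_URL = "http://localhost:8000"  # placeholder address for a running server


def fire_request(_: int) -> int:
    payload = {"text": "hello", "sampling_params": {}}  # placeholder request body
    return requests.post(f"{BASE_URL}/generate", json=payload, timeout=120).status_code


def test_overload_is_shed_with_503():
    # Send far more concurrent requests than max batch size + 2 can admit.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        codes = list(pool.map(fire_request, range(64)))
    # Every request must either be served or rejected cleanly; nothing else.
    assert set(codes) <= {200, 503}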

File tree: 3 files changed (+55, −8 lines)


shortfin/python/shortfin_apps/llm/components/generate.py

Lines changed: 29 additions & 4 deletions
@@ -18,6 +18,8 @@

 # TODO: Have a generic "Responder" interface vs just the concrete impl.
 from shortfin.interop.fastapi import FastAPIResponder
+from fastapi.responses import JSONResponse
+from fastapi import status

 from .config_struct import DecodeConfig
 from .io_struct import (
@@ -130,6 +132,7 @@ class ClientGenerateBatchProcess(sf.Process):
         "responder",
         "tokenizer",
         "decode_config",
+        "service",
     ]

     def __init__(
@@ -140,6 +143,7 @@ def __init__(
         fiber: sf.Fiber | None = None,
     ):
         super().__init__(fiber=service.main_fiber if fiber is None else fiber)
+        self.service = service
         self.gen_req = gen_req
         self.responder = responder
         self.tokenizer = service.tokenizer
@@ -151,12 +155,29 @@ def __init__(

     async def run(self):
         logger.debug("Started ClientBatchGenerateProcess: %r", self)
-        streaming = self.gen_req.stream
-        self.responder.start_response()
-        if streaming:
-            self.responder.stream_start()
+
+        # Try to add request to queue
+        # TODO(@zphoenixrises): Add load testing and integration tests for this.
+        if not self.service.add_to_queue():
+            error_response = JSONResponse(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                content={
+                    "error": "Server queue is full. Please try again later.",
+                    "code": "QUEUE_FULL",
+                    "current_size": self.service.current_queue_size,
+                    "max_size": self.service.max_queue_size,
+                },
+            )
+            self.responder.send_response(error_response)
+            self.responder.ensure_response()
+            return

         try:
+            streaming = self.gen_req.stream
+            self.responder.start_response()
+            if streaming:
+                self.responder.stream_start()
+
             # Launch all individual generate processes and wait for them to finish.
             gen_processes = []
             input_ids = self.gen_req.input_ids
@@ -166,6 +187,7 @@ async def run(self):
                 input_batch = [input_ids] if self.gen_req.is_single else input_ids
             else:
                 input_batch = self.tokenize()
+
             for index, input_tokens in enumerate(input_batch):
                 decode_config = copy(self.decode_config)
                 decode_config.update_from_sampling_params(
@@ -189,7 +211,10 @@ async def run(self):

             await asyncio.gather(*gen_processes)
             self.generate_response(gen_processes, streaming)
+
         finally:
+            # Remove request from queue when done
+            self.service.remove_from_queue()
             self.responder.ensure_response()

     def generate_response(
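
For callers, the new rejection path surfaces as an HTTP 503 whose JSON body carries error, code, current_size, and max_size. A minimal client-side sketch, not part of this diff, assuming the server listens at BASE_URL and exposes the /generate endpoint:

# Hypothetical retry helper; BASE_URL, the endpoint path, and the backoff
# parameters are assumptions for illustration.
import time

import requests

BASE_URL = "http://localhost:8000"  # placeholder address


def post_with_backoff(payload: dict, max_attempts: int = 5) -> requests.Response:
    delay = 0.5
    for _ in range(max_attempts):
        resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
        if resp.status_code != 503:
            return resp
        body = resp.json()
        if body.get("code") == "QUEUE_FULL":
            # The server reports its queue occupancy; wait and try again.
            print(f"queue full ({body.get('current_size')}/{body.get('max_size')})")
        time.sleep(delay)
        delay *= 2  # exponential backoff between attempts
    return resp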

shortfin/python/shortfin_apps/llm/components/service.py

Lines changed: 23 additions & 1 deletion
@@ -24,7 +24,6 @@

 from ...utils import GenerateService

-
 logger = logging.getLogger(__name__)


@@ -44,18 +43,41 @@ def __init__(
         model_params: ModelParams,
         server_params: "ServerParams",
         program_isolation: str = "per_call",
+        max_queue_size: int = 3,  # Maximum number of requests in queue
     ):
         super().__init__(sysman)
         self.name = name
         self.tokenizer = tokenizer
         self.model_params = model_params
         self.server_params = server_params
+        self.max_queue_size = max_queue_size
+        self.current_queue_size = 0

         self.set_isolation(program_isolation)
         self.initialize_worker_and_fiber()
+        self.initialize_queues()
         self.initialize_page_cache()

+    def initialize_queues(self):
+        """Initialize request and response queues"""
+        if self.model_params.decode_batch_sizes:
+            self.max_queue_size = max(self.model_params.decode_batch_sizes) + 2
+        print(f"Max queue size: {self.max_queue_size}")
+
+    def add_to_queue(self) -> bool:
+        """Try to add a request to the queue. Returns True if successful, False if queue is full."""
+        if self.current_queue_size >= self.max_queue_size:
+            return False
+        self.current_queue_size += 1
+        return True
+
+    def remove_from_queue(self):
+        """Remove a request from the queue."""
+        if self.current_queue_size > 0:
+            self.current_queue_size -= 1
+
     def initialize_worker_and_fiber(self):
+
         self.main_worker = self.sysman.ls.create_worker(f"{self.name}-inference")
         self.main_fiber = self.sysman.ls.create_fiber(self.main_worker)
         self.prefill_fiber = self.sysman.ls.create_fiber(self.main_worker)
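
The service tracks occupancy with a plain counter rather than an actual queue object, so correctness depends on every successful add_to_queue() being paired with exactly one remove_from_queue(). A small sketch, not from the diff, of that pairing and of the sizing rule (largest decode batch size plus 2, falling back to the constructor default of 3); the class name and example batch sizes are hypothetical:

# Illustrative stand-in for the GenerateService bookkeeping above.
class AdmissionCounter:
    def __init__(self, decode_batch_sizes: list[int], default_max: int = 3):
        # Same sizing rule as initialize_queues(): max decode batch size + 2.
        self.max_queue_size = (
            max(decode_batch_sizes) + 2 if decode_batch_sizes else default_max
        )
        self.current_queue_size = 0

    def add_to_queue(self) -> bool:
        if self.current_queue_size >= self.max_queue_size:
            return False
        self.current_queue_size += 1
        return True

    def remove_from_queue(self) -> None:
        if self.current_queue_size > 0:
            self.current_queue_size -= 1


counter = AdmissionCounter(decode_batch_sizes=[4, 8])  # max_queue_size == 10
if counter.add_to_queue():
    try:
        pass  # handle the request here
    finally:
        counter.remove_from_queue()  # always release, mirroring run()'s finally
else:
    pass  # reject with 503 QUEUE_FULL, as generate.py does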

shortfin/python/shortfin_apps/llm/server.py

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ def parse_args(argv):
     return parser.parse_args(argv)


-def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
+def run_server(argv, log_config=uvicorn.config.LOGGING_CONFIG, port: int | None = None):
     args = parse_args(argv)
     if args.tokenizer_config_json is None:
         # this is only used for the EOS token
@@ -84,7 +84,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
     uvicorn.run(
         get_app(lifecycle_manager.fastapi_lifespan),
         host=args.host,
-        port=args.port,
+        port=port or args.port,
         log_config=log_config,
         timeout_keep_alive=args.timeout_keep_alive,
     )
@@ -94,7 +94,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
     from shortfin.support.logging_setup import configure_main_logger

     logger = configure_main_logger("server")
-    main(
+    run_server(
         sys.argv[1:],
         # Make logging defer to the default shortfin logging config.
         log_config=UVICORN_LOG_CONFIG,
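
The rename from main to run_server and the new port keyword let a test import the entry point and pin it to a known port without editing argv. A hedged usage sketch, where the argv contents and the port number are placeholders for whatever a deployment normally passes:

# Hypothetical test harness; the empty argv and the port value are assumptions.
import threading

from shortfin_apps.llm.server import run_server

server_argv: list[str] = []  # the usual shortfin LLM server flags would go here

# Launch in the background; the port kwarg overrides any --port parsed from argv.
server_thread = threading.Thread(
    target=run_server,
    args=(server_argv,),
    kwargs={"port": 8123},
    daemon=True,
)
server_thread.start()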
