Skip to content

Commit c7d8d2f

Browse files
FrsECMpre-commit-ci[bot]aniketmaurya
authored
Fix Windows Threading Issues (#385)
* Fix bug on windows with uvicorn when multiple workers. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Force socket to listen before starting server * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Ctrl+C on windows * Update src/litserve/server.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix comments - Ctrl+C on Windows * Update src/litserve/server.py Fix Type hint Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com> * Update src/litserve/server.py Remove windows comment. Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com> * Fix threading import Thread * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Increase test timeout => 30mn * Fix default self._uvicorn_servers * Fix sockets iteration * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * No Need to catch Keyboard Interrupt on windows. just close Threads * Update Timeout + testing CICD * Fix timeout for gpu tests * Fix KeyboardInterrupt Windows * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * KeyboardInterrupt Windows - MultipleWorkers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move pid detection to LitLoop for less intrusivity * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert changes in CICD * Remove lock (useless) * Apply suggestions from code review --------- Co-authored-by: Francois Ponchon <francois.ponchon@michelin.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
1 parent 7e01984 commit c7d8d2f

File tree

4 files changed

+54
-13
lines changed

4 files changed

+54
-13
lines changed

src/litserve/loops/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import asyncio
1515
import inspect
1616
import logging
17+
import os
1718
import pickle
19+
import signal
1820
import sys
1921
import time
2022
from abc import ABC
@@ -212,6 +214,15 @@ def run(
212214
class LitLoop(_BaseLoop):
213215
def __init__(self):
214216
# Per-instance context mapping; starts out empty.
self._context = {}
217+
# PID of the process that constructed this loop; kill() sends SIGTERM to it.
self._server_pid = os.getpid()
218+
219+
def kill(self):
    """Ask the process that created this loop to terminate.

    Sends ``SIGTERM`` to ``self._server_pid`` (the PID captured in
    ``__init__``). The loop implementations call this when they catch a
    ``KeyboardInterrupt`` so that the interrupt is propagated to the
    process that owns the server.

    The call is best-effort: if the target process is already gone, the
    resulting error is swallowed and the method simply returns ``None``.
    """
    try:
        print(f"Stop Server Requested - Kill parent pid [{self._server_pid}] from [{os.getpid()}]")
        os.kill(self._server_pid, signal.SIGTERM)
    except (PermissionError, ProcessLookupError):
        # The target process has already been killed:
        # - PermissionError ("Access Denied") is what the original code guards against;
        # - ProcessLookupError (ESRCH) is what POSIX raises when the PID no longer exists.
        return
215226

216227
def get_batch_requests(
217228
self,

src/litserve/loops/simple_loops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def run_single_loop(
4242
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
4343
except (Empty, ValueError):
4444
continue
45+
except KeyboardInterrupt: # pragma: no cover
46+
self.kill()
47+
return
4548

4649
if (lit_api.request_timeout and lit_api.request_timeout != -1) and (
4750
time.monotonic() - timestamp > lit_api.request_timeout
@@ -213,7 +216,9 @@ def run_batched_loop(
213216
PickleableHTTPException.from_exception(e),
214217
LitAPIStatus.ERROR,
215218
)
216-
219+
except KeyboardInterrupt: # pragma: no cover
220+
self.kill()
221+
return
217222
except Exception as e:
218223
logger.exception(
219224
"LitAPI ran into an error while processing the batched request.\n"

src/litserve/loops/streaming_loops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def run_streaming_loop(
9797
PickleableHTTPException.from_exception(e),
9898
LitAPIStatus.ERROR,
9999
)
100+
except KeyboardInterrupt: # pragma: no cover
101+
self.kill()
102+
return
100103
except Exception as e:
101104
logger.exception(
102105
"LitAPI ran into an error while processing the streaming request uid=%s.\n"
@@ -185,6 +188,9 @@ def run_batched_streaming_loop(
185188

186189
for response_queue_id, uid in zip(response_queue_ids, uids):
187190
self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)
191+
except KeyboardInterrupt: # pragma: no cover
192+
self.kill()
193+
return
188194

189195
except HTTPException as e:
190196
for response_queue_id, uid in zip(response_queue_ids, uids):

src/litserve/server.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
import warnings
2727
from collections import deque
2828
from contextlib import asynccontextmanager
29+
from multiprocessing.context import Process
30+
from threading import Thread
2931
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
3032

3133
import uvicorn
34+
import uvicorn.server
3235
from fastapi import Depends, FastAPI, HTTPException, Request, Response
3336
from fastapi.responses import JSONResponse, StreamingResponse
3437
from fastapi.security import APIKeyHeader
@@ -176,7 +179,6 @@ def __init__(
176179
DeprecationWarning,
177180
stacklevel=2,
178181
)
179-
180182
lit_api.max_batch_size = max_batch_size
181183
lit_api.batch_timeout = batch_timeout
182184
if isinstance(spec, LitSpec):
@@ -341,6 +343,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
341343
),
342344
)
343345
process.start()
346+
print(f"Inference Worker {worker_id} - [{process.pid}]")
344347
process_list.append(process)
345348
return manager, process_list
346349

@@ -599,20 +602,28 @@ def run(
599602
elif api_server_worker_type is None:
600603
api_server_worker_type = "process"
601604

602-
manager, litserve_workers = self.launch_inference_worker(num_api_servers)
605+
manager, inference_workers = self.launch_inference_worker(num_api_servers)
603606

604607
self.verify_worker_status()
605608
try:
606-
servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs)
609+
uvicorn_workers = self._start_server(
610+
port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs
611+
)
607612
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
608-
for s in servers:
609-
s.join()
613+
# On Linux, kill signal will be captured by uvicorn.
614+
# => The workers will join and raise a KeyboardInterrupt, allowing the server to shut down.
615+
for i, uw in enumerate(uvicorn_workers):
616+
uw: Union[Process, Thread]
617+
if isinstance(uw, Process):
618+
print(f"Uvicorn worker {i} : [{uw.pid}]")
619+
uw.join()
610620
finally:
611621
print("Shutting down LitServe")
612622
self._transport.close()
613-
for w in litserve_workers:
614-
w.terminate()
615-
w.join()
623+
for iw in inference_workers:
624+
iw: Process
625+
iw.terminate()
626+
iw.join()
616627
manager.shutdown()
617628

618629
def _prepare_app_run(self, app: FastAPI):
@@ -622,16 +633,24 @@ def _prepare_app_run(self, app: FastAPI):
622633
app.add_middleware(RequestCountMiddleware, active_counter=active_counter)
623634

624635
def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type, **kwargs):
625-
servers = []
636+
workers = []
626637
for response_queue_id in range(num_uvicorn_servers):
627638
self.app.response_queue_id = response_queue_id
628639
if self.lit_spec:
629640
self.lit_spec.response_queue_id = response_queue_id
630641
app: FastAPI = copy.copy(self.app)
631642

632643
self._prepare_app_run(app)
633-
634644
config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
645+
if sys.platform == "win32" and num_uvicorn_servers > 1:
646+
logger.debug("Enable Windows explicit socket sharing...")
647+
# Make sure the sockets are listening...
648+
# This prevents a subsequent [WinError 10022]
649+
for sock in sockets:
650+
sock.listen(config.backlog)
651+
# Set `workers` to tell uvicorn to use a shared socket (win32)
652+
# https://github.com/encode/uvicorn/pull/802
653+
config.workers = num_uvicorn_servers
635654
server = uvicorn.Server(config=config)
636655
if uvicorn_worker_type == "process":
637656
ctx = mp.get_context("fork")
@@ -641,8 +660,8 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
641660
else:
642661
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
643662
w.start()
644-
servers.append(w)
645-
return servers
663+
workers.append(w)
664+
return workers
646665

647666
def setup_auth(self):
648667
if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize):

0 commit comments

Comments
 (0)