Skip to content

Commit 7c39c68

Browse files
committed
Support restarting workers after max requests
This is useful as a "solution" to memory leaks in apps as it ensures that after the max requests have been handled the worker will restart hence freeing any memory leak. The options match those used by Gunicorn. This also ensures that the workers self-heal such that if a worker crashes it will be restored.
1 parent c0468e5 commit 7c39c68

21 files changed

+163
-68
lines changed

src/hypercorn/__main__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ def main(sys_args: Optional[List[str]] = None) -> int:
8989
default=sentinel,
9090
type=int,
9191
)
92+
parser.add_argument(
93+
"--max-requests",
94+
help="""Maximum number of requests a worker will process before restarting""",
95+
default=sentinel,
96+
type=int,
97+
)
98+
parser.add_argument(
99+
"--max-requests-jitter",
100+
help="This jitter causes the max-requests per worker to be "
101+
"randomized by randint(0, max_requests_jitter)",
102+
default=sentinel,
103+
type=int,
104+
)
92105
parser.add_argument(
93106
"-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int
94107
)
@@ -252,6 +265,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode:
252265
config.keyfile_password = args.keyfile_password
253266
if args.log_config is not sentinel:
254267
config.logconfig = args.log_config
268+
if args.max_requests is not sentinel:
269+
config.max_requests = args.max_requests
270+
if args.max_requests_jitter is not sentinel:
271+
config.max_requests_jitter = args.max_requests
255272
if args.pid is not sentinel:
256273
config.pid_path = args.pid
257274
if args.root_path is not sentinel:

src/hypercorn/asyncio/run.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import platform
55
import signal
66
import ssl
7+
import sys
78
from functools import partial
89
from multiprocessing.synchronize import Event as EventType
910
from os import getpid
11+
from random import randint
1012
from socket import socket
1113
from typing import Any, Awaitable, Callable, Optional, Set
1214

@@ -30,6 +32,14 @@
3032
except ImportError:
3133
from taskgroup import Runner # type: ignore
3234

35+
try:
36+
from asyncio import TaskGroup
37+
except ImportError:
38+
from taskgroup import TaskGroup # type: ignore
39+
40+
if sys.version_info < (3, 11):
41+
from exceptiongroup import BaseExceptionGroup
42+
3343

3444
def _share_socket(sock: socket) -> socket:
3545
# Windows requires the socket be explicitly shared across
@@ -84,7 +94,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803
8494
ssl_context = config.create_ssl_context()
8595
ssl_handshake_timeout = config.ssl_handshake_timeout
8696

87-
context = WorkerContext()
97+
max_requests = None
98+
if config.max_requests is not None:
99+
max_requests = config.max_requests + randint(0, config.max_requests_jitter)
100+
context = WorkerContext(max_requests)
88101
server_tasks: Set[asyncio.Task] = set()
89102

90103
async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
@@ -136,7 +149,13 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
136149
await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")
137150

138151
try:
139-
await raise_shutdown(shutdown_trigger)
152+
async with TaskGroup() as task_group:
153+
task_group.create_task(raise_shutdown(shutdown_trigger))
154+
task_group.create_task(raise_shutdown(context.terminate.wait))
155+
except BaseExceptionGroup as error:
156+
_, other_errors = error.split((ShutdownError, KeyboardInterrupt))
157+
if other_errors is not None:
158+
raise other_errors
140159
except (ShutdownError, KeyboardInterrupt):
141160
pass
142161
finally:

src/hypercorn/asyncio/worker_context.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from typing import Type, Union
4+
from typing import Optional, Type, Union
55

66
from ..typing import Event
77

@@ -26,9 +26,20 @@ def is_set(self) -> bool:
2626
class WorkerContext:
2727
event_class: Type[Event] = EventWrapper
2828

29-
def __init__(self) -> None:
29+
def __init__(self, max_requests: Optional[int]) -> None:
30+
self.max_requests = max_requests
31+
self.requests = 0
32+
self.terminate = self.event_class()
3033
self.terminated = self.event_class()
3134

35+
async def mark_request(self) -> None:
36+
if self.max_requests is None:
37+
return
38+
39+
self.requests += 1
40+
if self.requests > self.max_requests:
41+
await self.terminate.set()
42+
3243
@staticmethod
3344
async def sleep(wait: Union[float, int]) -> None:
3445
return await asyncio.sleep(wait)

src/hypercorn/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ class Config:
9292
logger_class = Logger
9393
loglevel: str = "INFO"
9494
max_app_queue_size: int = 10
95+
max_requests: Optional[int] = None
96+
max_requests_jitter: int = 0
9597
pid_path: Optional[str] = None
9698
server_names: List[str] = []
9799
shutdown_timeout = 60 * SECONDS

src/hypercorn/protocol/h11.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ async def _create_stream(self, request: h11.Request) -> None:
236236
)
237237
)
238238
self.keep_alive_requests += 1
239+
await self.context.mark_request()
239240

240241
async def _send_h11_event(self, event: H11SendableEvent) -> None:
241242
try:

src/hypercorn/protocol/h2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
354354
)
355355
)
356356
self.keep_alive_requests += 1
357+
await self.context.mark_request()
357358

358359
async def _create_server_push(
359360
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]

src/hypercorn/protocol/h3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
125125
raw_path=raw_path,
126126
)
127127
)
128+
await self.context.mark_request()
128129

129130
async def _create_server_push(
130131
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]

src/hypercorn/run.py

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import signal
55
import time
66
from multiprocessing import get_context
7+
from multiprocessing.connection import wait
78
from multiprocessing.context import BaseContext
89
from multiprocessing.process import BaseProcess
910
from multiprocessing.synchronize import Event as EventType
@@ -12,12 +13,10 @@
1213

1314
from .config import Config, Sockets
1415
from .typing import WorkerFunc
15-
from .utils import load_application, wait_for_changes, write_pid_file
16+
from .utils import check_for_updates, files_to_watch, load_application, write_pid_file
1617

1718

1819
def run(config: Config) -> int:
19-
exit_code = 0
20-
2120
if config.pid_path is not None:
2221
write_pid_file(config.pid_path)
2322

@@ -42,67 +41,82 @@ def run(config: Config) -> int:
4241
if config.use_reloader and config.workers == 0:
4342
raise RuntimeError("Cannot reload without workers")
4443

45-
if config.use_reloader or config.workers == 0:
46-
# Load the application so that the correct paths are checked for
47-
# changes, but only when the reloader is being used.
48-
load_application(config.application_path, config.wsgi_max_body_size)
49-
44+
exitcode = 0
5045
if config.workers == 0:
5146
worker_func(config, sockets)
5247
else:
48+
if config.use_reloader:
49+
# Load the application so that the correct paths are checked for
50+
# changes, but only when the reloader is being used.
51+
load_application(config.application_path, config.wsgi_max_body_size)
52+
5353
ctx = get_context("spawn")
5454

5555
active = True
56+
shutdown_event = ctx.Event()
57+
58+
def shutdown(*args: Any) -> None:
59+
nonlocal active, shutdown_event
60+
shutdown_event.set()
61+
active = False
62+
63+
processes: List[BaseProcess] = []
5664
while active:
5765
# Ignore SIGINT before creating the processes, so that they
5866
# inherit the signal handling. This means that the shutdown
5967
# function controls the shutdown.
6068
signal.signal(signal.SIGINT, signal.SIG_IGN)
6169

62-
shutdown_event = ctx.Event()
63-
processes = start_processes(config, worker_func, sockets, shutdown_event, ctx)
64-
65-
def shutdown(*args: Any) -> None:
66-
nonlocal active, shutdown_event
67-
shutdown_event.set()
68-
active = False
70+
_populate(processes, config, worker_func, sockets, shutdown_event, ctx)
6971

7072
for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
7173
if hasattr(signal, signal_name):
7274
signal.signal(getattr(signal, signal_name), shutdown)
7375

7476
if config.use_reloader:
75-
wait_for_changes(shutdown_event)
76-
shutdown_event.set()
77+
files = files_to_watch()
78+
while True:
79+
finished = wait((process.sentinel for process in processes), timeout=1)
80+
updated = check_for_updates(files)
81+
if updated:
82+
shutdown_event.set()
83+
for process in processes:
84+
process.join()
85+
shutdown_event.clear()
86+
break
87+
if len(finished) > 0:
88+
break
7789
else:
78-
active = False
90+
wait(process.sentinel for process in processes)
7991

80-
for process in processes:
81-
process.join()
82-
if process.exitcode != 0:
83-
exit_code = process.exitcode
92+
exitcode = _join_exited(processes)
93+
if exitcode != 0:
94+
shutdown_event.set()
95+
active = False
8496

8597
for process in processes:
8698
process.terminate()
8799

100+
exitcode = _join_exited(processes) if exitcode != 0 else exitcode
101+
88102
for sock in sockets.secure_sockets:
89103
sock.close()
90104

91105
for sock in sockets.insecure_sockets:
92106
sock.close()
93107

94-
return exit_code
108+
return exitcode
95109

96110

97-
def start_processes(
111+
def _populate(
112+
processes: List[BaseProcess],
98113
config: Config,
99114
worker_func: WorkerFunc,
100115
sockets: Sockets,
101116
shutdown_event: EventType,
102117
ctx: BaseContext,
103-
) -> List[BaseProcess]:
104-
processes = []
105-
for _ in range(config.workers):
118+
) -> None:
119+
for _ in range(config.workers - len(processes)):
106120
process = ctx.Process( # type: ignore
107121
target=worker_func,
108122
kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets},
@@ -117,4 +131,15 @@ def start_processes(
117131
processes.append(process)
118132
if platform.system() == "Windows":
119133
time.sleep(0.1)
120-
return processes
134+
135+
136+
def _join_exited(processes: List[BaseProcess]) -> int:
137+
exitcode = 0
138+
for index in reversed(range(len(processes))):
139+
worker = processes[index]
140+
if worker.exitcode is not None:
141+
worker.join()
142+
exitcode = worker.exitcode if exitcode == 0 else exitcode
143+
del processes[index]
144+
145+
return exitcode

src/hypercorn/trio/run.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from functools import partial
55
from multiprocessing.synchronize import Event as EventType
6+
from random import randint
67
from typing import Awaitable, Callable, Optional
78

89
import trio
@@ -37,7 +38,10 @@ async def worker_serve(
3738
config.set_statsd_logger_class(StatsdLogger)
3839

3940
lifespan = Lifespan(app, config)
40-
context = WorkerContext()
41+
max_requests = None
42+
if config.max_requests is not None:
43+
max_requests = config.max_requests + randint(0, config.max_requests_jitter)
44+
context = WorkerContext(max_requests)
4145

4246
async with trio.open_nursery() as lifespan_nursery:
4347
await lifespan_nursery.start(lifespan.handle_lifespan)
@@ -82,6 +86,7 @@ async def worker_serve(
8286
async with trio.open_nursery(strict_exception_groups=True) as nursery:
8387
if shutdown_trigger is not None:
8488
nursery.start_soon(raise_shutdown, shutdown_trigger)
89+
nursery.start_soon(raise_shutdown, context.terminate.wait)
8590

8691
nursery.start_soon(
8792
partial(

src/hypercorn/trio/worker_context.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Type, Union
3+
from typing import Optional, Type, Union
44

55
import trio
66

@@ -27,9 +27,20 @@ def is_set(self) -> bool:
2727
class WorkerContext:
2828
event_class: Type[Event] = EventWrapper
2929

30-
def __init__(self) -> None:
30+
def __init__(self, max_requests: Optional[int]) -> None:
31+
self.max_requests = max_requests
32+
self.requests = 0
33+
self.terminate = self.event_class()
3134
self.terminated = self.event_class()
3235

36+
async def mark_request(self) -> None:
37+
if self.max_requests is None:
38+
return
39+
40+
self.requests += 1
41+
if self.requests > self.max_requests:
42+
await self.terminate.set()
43+
3344
@staticmethod
3445
async def sleep(wait: Union[float, int]) -> None:
3546
return await trio.sleep(wait)

0 commit comments

Comments
 (0)