Skip to content

Commit ce5cfc6

Browse files
committed
Add signal handlers for graceful disconnects:
- On ``SIGTERM`` and ``SIGNINT`` (KeyboardInterrupt), disconnect gracefully from the provider. - KeyboardInterrupt used to hang for a while because of the open connection. This makes it so we no longer wait the whole connection_timeout and it raises the KeyboardInterrupt exception immediately. Leave the responsibility of handling the KeyboardInterrupt to the user.
1 parent 82d6732 commit ce5cfc6

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

web3/providers/persistent/async_ipc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async def _provider_specific_connect(self) -> None:
140140
)
141141

142142
async def _provider_specific_disconnect(self) -> None:
143+
# this should remain idempotent
143144
if self._writer and not self._writer.is_closing():
144145
self._writer.close()
145146
await self._writer.wait_closed()

web3/providers/persistent/persistent.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
import asyncio
66
import logging
7+
import signal
78
from typing import (
89
TYPE_CHECKING,
910
Any,
@@ -176,6 +177,7 @@ async def connect(self) -> None:
176177
self.logger.info(
177178
f"Successfully connected to: {self.get_endpoint_uri_or_ipc_path()}"
178179
)
180+
self._set_signal_handlers()
179181
break
180182
except (WebSocketException, OSError) as e:
181183
if _connection_attempts == self._max_connection_retries:
@@ -192,6 +194,7 @@ async def connect(self) -> None:
192194
_backoff_time *= _backoff_rate_change
193195

194196
async def disconnect(self) -> None:
197+
# this should remain idempotent
195198
try:
196199
if self._message_listener_task:
197200
self._message_listener_task.cancel()
@@ -260,11 +263,35 @@ async def _provider_specific_connect(self) -> None:
260263
raise NotImplementedError("Must be implemented by subclasses")
261264

262265
async def _provider_specific_disconnect(self) -> None:
266+
# this method should be idempotent
263267
raise NotImplementedError("Must be implemented by subclasses")
264268

265269
async def _provider_specific_socket_reader(self) -> RPCResponse:
266270
raise NotImplementedError("Must be implemented by subclasses")
267271

272+
def _set_signal_handlers(self) -> None:
273+
loop = asyncio.get_event_loop()
274+
275+
def extended_handler(sig: int, frame: Any, existing_handler: Any) -> None:
276+
loop.create_task(self.disconnect())
277+
278+
# invoke the existing handler, if callable
279+
if callable(existing_handler):
280+
existing_handler(sig, frame)
281+
282+
existing_sigint_handler = signal.getsignal(signal.SIGINT)
283+
existing_sigterm_handler = signal.getsignal(signal.SIGTERM)
284+
285+
# extend the existing signal handlers to include the disconnect method
286+
signal.signal(
287+
signal.SIGINT,
288+
lambda sig, frame: extended_handler(sig, frame, existing_sigint_handler),
289+
)
290+
signal.signal(
291+
signal.SIGTERM,
292+
lambda sig, frame: extended_handler(sig, frame, existing_sigterm_handler),
293+
)
294+
268295
def _message_listener_callback(
269296
self, message_listener_task: "asyncio.Task[None]"
270297
) -> None:

web3/providers/persistent/websocket.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ async def _provider_specific_connect(self) -> None:
132132
self._ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
133133

134134
async def _provider_specific_disconnect(self) -> None:
135+
# this should remain idempotent
135136
if self._ws is not None and not self._ws.closed:
136137
await self._ws.close()
137138
self._ws = None

0 commit comments

Comments
 (0)