diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 609b046..598a882 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -32,7 +32,7 @@ MultiAccountId, ) from websockets.asyncio.client import connect -from websockets.exceptions import ConnectionClosed +from websockets.exceptions import ConnectionClosed, WebSocketException from async_substrate_interface.const import SS58_FORMAT from async_substrate_interface.errors import ( @@ -535,6 +535,8 @@ def __init__( self._open_subscriptions = 0 self._options = options if options else {} self._log_raw_websockets = _log_raw_websockets + self._is_connecting = False + self._is_closing = False try: now = asyncio.get_running_loop().time() @@ -560,38 +562,63 @@ async def __aenter__(self): async def loop_time() -> float: return asyncio.get_running_loop().time() + async def _cancel(self): + try: + self._receiving_task.cancel() + await self._receiving_task + await self.ws.close() + except ( + AttributeError, + asyncio.CancelledError, + WebSocketException, + ): + pass + except Exception as e: + logger.warning( + f"{e} encountered while trying to close websocket connection." + ) + async def connect(self, force=False): - now = await self.loop_time() - self.last_received = now - self.last_sent = now - if self._exit_task: - self._exit_task.cancel() - async with self._lock: - if not self._initialized or force: - try: - self._receiving_task.cancel() - await self._receiving_task - await self.ws.close() - except (AttributeError, asyncio.CancelledError): - pass - self.ws = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10 - ) - self._receiving_task = asyncio.create_task(self._start_receiving()) - self._initialized = True + self._is_connecting = True + try: + now = await self.loop_time() + self.last_received = now + self.last_sent = now + if self._exit_task: + self._exit_task.cancel() + if not self._is_closing: + if not self._initialized or force: + try: + await asyncio.wait_for(self._cancel(), timeout=10.0) + except asyncio.TimeoutError: + pass + + self.ws = await asyncio.wait_for( + connect(self.ws_url, **self._options), timeout=10.0 + ) + self._receiving_task = asyncio.get_running_loop().create_task( + self._start_receiving() + ) + self._initialized = True + finally: + self._is_connecting = False async def __aexit__(self, exc_type, exc_val, exc_tb): - async with self._lock: # TODO is this actually what I want to happen? - self._in_use -= 1 - if self._exit_task is not None: - self._exit_task.cancel() - try: - await self._exit_task - except asyncio.CancelledError: - pass - if self._in_use == 0 and self.ws is not None: - self._open_subscriptions = 0 - self._exit_task = asyncio.create_task(self._exit_with_timer()) + self._is_closing = True + try: + if not self._is_connecting: + self._in_use -= 1 + if self._exit_task is not None: + self._exit_task.cancel() + try: + await self._exit_task + except asyncio.CancelledError: + pass + if self._in_use == 0 and self.ws is not None: + self._open_subscriptions = 0 + self._exit_task = asyncio.create_task(self._exit_with_timer()) + finally: + self._is_closing = False async def _exit_with_timer(self): """ @@ -605,16 +632,15 @@ async def _exit_with_timer(self): pass async def shutdown(self): - async with self._lock: - try: - self._receiving_task.cancel() - await self._receiving_task - await self.ws.close() - except (AttributeError, asyncio.CancelledError): - pass - self.ws = None - self._initialized = False - self._receiving_task = None + self._is_closing = True + try: + await asyncio.wait_for(self._cancel(), timeout=10.0) + except asyncio.TimeoutError: + pass + self.ws = None + self._initialized = False + self._receiving_task = None + self._is_closing = False async def _recv(self) -> None: try: @@ -624,10 +650,6 @@ async def _recv(self) -> None: raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd.decode()}") response = json.loads(recd) self.last_received = await self.loop_time() - async with self._lock: - # note that these 'subscriptions' are all waiting sent messages which have not received - # responses, and thus are not the same as RPC 'subscriptions', which are unique - self._open_subscriptions -= 1 if "id" in response: self._received[response["id"]] = response self._in_use_ids.remove(response["id"]) @@ -647,8 +669,7 @@ async def _start_receiving(self): except asyncio.CancelledError: pass except ConnectionClosed: - async with self._lock: - await self.connect(force=True) + await self.connect(force=True) async def send(self, payload: dict) -> int: """ @@ -674,8 +695,7 @@ async def send(self, payload: dict) -> int: self.last_sent = await self.loop_time() return original_id except (ConnectionClosed, ssl.SSLError, EOFError): - async with self._lock: - await self.connect(force=True) + await self.connect(force=True) async def retrieve(self, item_id: int) -> Optional[dict]: """ @@ -710,6 +730,7 @@ def __init__( retry_timeout: float = 60.0, _mock: bool = False, _log_raw_websockets: bool = False, + ws_shutdown_timer: float = 5.0, ): """ The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to @@ -728,6 +749,7 @@ def __init__( retry_timeout: how to long wait since the last ping to retry the RPC request _mock: whether to use mock version of the subtensor interface _log_raw_websockets: whether to log raw websocket requests during RPC requests + ws_shutdown_timer: how long after the last connection your websocket should close """ self.max_retries = max_retries @@ -744,6 +766,7 @@ def __init__( "max_size": self.ws_max_size, "write_limit": 2**16, }, + shutdown_timer=ws_shutdown_timer, ) else: self.ws = AsyncMock(spec=Websocket) diff --git a/async_substrate_interface/substrate_addons.py b/async_substrate_interface/substrate_addons.py index 7ec19b4..578eb80 100644 --- a/async_substrate_interface/substrate_addons.py +++ b/async_substrate_interface/substrate_addons.py @@ -264,6 +264,7 @@ def __init__( _mock: bool = False, _log_raw_websockets: bool = False, archive_nodes: Optional[list[str]] = None, + ws_shutdown_timer: float = 5.0, ): fallback_chains = fallback_chains or [] archive_nodes = archive_nodes or [] @@ -291,6 +292,7 @@ def __init__( retry_timeout=retry_timeout, max_retries=max_retries, _log_raw_websockets=_log_raw_websockets, + ws_shutdown_timer=ws_shutdown_timer, ) self._original_methods = { method: getattr(self, method) for method in RETRY_METHODS diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index ea76595..a64d570 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import AsyncMock, MagicMock import pytest @@ -91,3 +92,22 @@ async def test_runtime_call(monkeypatch): substrate.rpc_request.assert_any_call( "state_call", ["SubstrateApi_SubstrateMethod", "", None] ) + + +@pytest.mark.asyncio +async def test_websocket_shutdown_timer(): + # using default ws shutdown timer of 5.0 seconds + async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate: + await substrate.get_chain_head() + await asyncio.sleep(6) + assert ( + substrate.ws._initialized is False + ) # connection should have closed automatically + + # using custom ws shutdown timer of 10.0 seconds + async with AsyncSubstrateInterface( + "wss://lite.sub.latent.to:443", ws_shutdown_timer=10.0 + ) as substrate: + await substrate.get_chain_head() + await asyncio.sleep(6) # same sleep time as before + assert substrate.ws._initialized is True # connection should still be open