Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def __init__(
options: Optional[dict] = None,
_log_raw_websockets: bool = False,
retry_timeout: float = 60.0,
max_retries: int = 5,
):
"""
Websocket manager object. Allows for the use of a single websocket connection by multiple
Expand All @@ -536,6 +537,10 @@ def __init__(
max_subscriptions: Maximum number of subscriptions per websocket connection
max_connections: Maximum number of connections total
shutdown_timer: Number of seconds to shut down websocket connection after last use
options: Options to pass to the websocket connection
_log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
retry_timeout: Timeout in seconds to retry websocket connection
max_retries: Maximum number of retries following a timeout
"""
# TODO allow setting max concurrent connections and rpc subscriptions per connection
self.ws_url = ws_url
Expand All @@ -555,6 +560,7 @@ def __init__(
self._options = options if options else {}
self._log_raw_websockets = _log_raw_websockets
self._in_use_ids = set()
self._max_retries = max_retries

@property
def state(self):
Expand All @@ -575,7 +581,6 @@ async def loop_time() -> float:
async def _cancel(self):
try:
self._send_recv_task.cancel()
await self._send_recv_task
await self.ws.close()
except (
AttributeError,
Expand Down Expand Up @@ -616,19 +621,30 @@ async def _handler(self, ws: ClientConnection) -> None:
)
loop = asyncio.get_running_loop()
should_reconnect = False
is_retry = False
for task in pending:
task.cancel()
for task in done:
if isinstance(task.result(), (asyncio.TimeoutError, ConnectionClosed)):
task_res = task.result()
if isinstance(
task_res, (asyncio.TimeoutError, ConnectionClosed, TimeoutError)
):
should_reconnect = True
if isinstance(task_res, (asyncio.TimeoutError, TimeoutError)):
self._attempts += 1
is_retry = True
if should_reconnect is True:
for original_id, payload in list(self._inflight.items()):
self._received[original_id] = loop.create_future()
to_send = json.loads(payload)
await self._sending.put(to_send)
logger.info("Timeout occurred. Reconnecting.")
if is_retry:
# Otherwise the connection was just closed due to no activity, which should not count against retries
logger.info(
f"Timeout occurred. Reconnecting. Attempt {self._attempts} of {self._max_retries}"
)
await self.connect(True)
await self._handler(ws=ws)
await self._handler(ws=self.ws)
elif isinstance(e := recv_task.result(), Exception):
return e
elif isinstance(e := send_task.result(), Exception):
Expand Down Expand Up @@ -689,15 +705,22 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
recd = await asyncio.wait_for(
ws.recv(decode=False), timeout=self.retry_timeout
)
# reset the counter once we successfully receive something back
self._attempts = 0
await self._recv(recd)
except Exception as e:
logger.exception("Start receiving exception", exc_info=e)
if isinstance(e, ssl.SSLError):
e = ConnectionClosed
for fut in self._received.values():
if not fut.done():
fut.set_exception(e)
fut.cancel()
if not isinstance(
e, (asyncio.TimeoutError, TimeoutError, ConnectionClosed)
):
logger.exception("Websocket receiving exception", exc_info=e)
for fut in self._received.values():
if not fut.done():
fut.set_exception(e)
fut.cancel()
else:
logger.warning("Timeout occurred. Reconnecting.")
return e

async def _start_sending(self, ws) -> Exception:
Expand All @@ -713,14 +736,21 @@ async def _start_sending(self, ws) -> Exception:
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
await ws.send(to_send)
except Exception as e:
logger.exception("Start sending exception", exc_info=e)
if to_send is not None:
self._received[to_send["id"]].set_exception(e)
self._received[to_send["id"]].cancel()
if isinstance(e, ssl.SSLError):
e = ConnectionClosed
if not isinstance(
e, (asyncio.TimeoutError, TimeoutError, ConnectionClosed)
):
logger.exception("Websocket sending exception", exc_info=e)
if to_send is not None:
self._received[to_send["id"]].set_exception(e)
self._received[to_send["id"]].cancel()
else:
for i in self._received.keys():
self._received[i].set_exception(e)
self._received[i].cancel()
else:
for i in self._received.keys():
self._received[i].set_exception(e)
self._received[i].cancel()
logger.warning("Timeout occurred. Reconnecting.")
return e

async def send(self, payload: dict) -> str:
Expand Down Expand Up @@ -860,6 +890,7 @@ def __init__(
},
shutdown_timer=ws_shutdown_timer,
retry_timeout=self.retry_timeout,
max_retries=max_retries,
)
else:
self.ws = AsyncMock(spec=Websocket)
Expand Down
12 changes: 12 additions & 0 deletions tests/integration_tests/test_async_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import time

import pytest
Expand Down Expand Up @@ -149,3 +150,14 @@ async def test_query_multiple():
storage_function="OwnedHotkeys",
block_hash=block_hash,
)


@pytest.mark.asyncio
async def test_reconnection():
async with AsyncSubstrateInterface(
ARCHIVE_ENTRYPOINT, ss58_format=42, retry_timeout=8.0
) as substrate:
await asyncio.sleep(9) # sleep for longer than the retry timeout
bh = await substrate.get_chain_finalised_head()
assert isinstance(bh, str)
assert isinstance(await substrate.get_block_number(bh), int)
Loading