Skip to content

Commit 8f54649

Browse files
committed
WIP
1 parent 78cf8c2 commit 8f54649

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ def __init__(
524524
shutdown_timer=5,
525525
options: Optional[dict] = None,
526526
_log_raw_websockets: bool = False,
527+
retry_timeout: float = 60.0
527528
):
528529
"""
529530
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -542,10 +543,12 @@ def __init__(
542543
self.max_subscriptions = asyncio.Semaphore(max_subscriptions)
543544
self.max_connections = max_connections
544545
self.shutdown_timer = shutdown_timer
546+
self.retry_timeout = retry_timeout
545547
self._received: dict[str, asyncio.Future] = {}
546548
self._received_subscriptions: dict[str, asyncio.Queue] = {}
547549
self._sending = asyncio.Queue()
548-
self._receiving_task = None # TODO rename, as this now does send/recv
550+
self._send_recv_task = None
551+
self._inflight: dict[str, str] = {}
549552
self._attempts = 0
550553
self._initialized = False # TODO remove
551554
self._lock = asyncio.Lock()
@@ -586,8 +589,8 @@ async def loop_time() -> float:
586589

587590
async def _cancel(self):
588591
try:
589-
self._receiving_task.cancel()
590-
await self._receiving_task
592+
self._send_recv_task.cancel()
593+
await self._send_recv_task
591594
await self.ws.close()
592595
except (
593596
AttributeError,
@@ -601,13 +604,14 @@ async def _cancel(self):
601604
)
602605

603606
async def connect(self, force=False):
607+
# TODO after connecting, move from _inflight to the queue
604608
now = await self.loop_time()
605609
self.last_received = now
606610
self.last_sent = now
607611
async with self._lock:
608612
if self._exit_task:
609613
self._exit_task.cancel()
610-
if self.state not in (State.OPEN, State.CONNECTING):
614+
if self.state not in (State.OPEN, State.CONNECTING) or force:
611615
if not self._initialized or force:
612616
try:
613617
await asyncio.wait_for(self._cancel(), timeout=10.0)
@@ -616,21 +620,34 @@ async def connect(self, force=False):
616620
self.ws = await asyncio.wait_for(
617621
connect(self.ws_url, **self._options), timeout=10.0
618622
)
619-
if self._receiving_task is None or self._receiving_task.done():
620-
self._receiving_task = asyncio.get_running_loop().create_task(
623+
if self._send_recv_task is None or self._send_recv_task.done():
624+
self._send_recv_task = asyncio.get_running_loop().create_task(
621625
self._handler(self.ws)
622626
)
623627
self._initialized = True
624628

625-
async def _handler(self, ws: ClientConnection):
626-
consumer_task = asyncio.create_task(self._start_receiving(ws))
627-
producer_task = asyncio.create_task(self._start_sending(ws))
629+
async def _handler(self, ws: ClientConnection) -> None:
630+
recv_task = asyncio.create_task(self._start_receiving(ws))
631+
send_task = asyncio.create_task(self._start_sending(ws))
628632
done, pending = await asyncio.wait(
629-
[consumer_task, producer_task],
633+
[recv_task, send_task],
630634
return_when=asyncio.FIRST_COMPLETED,
631635
)
636+
loop = asyncio.get_running_loop()
637+
should_reconnect = False
632638
for task in pending:
633639
task.cancel()
640+
if isinstance(task.exception(), asyncio.TimeoutError):
641+
should_reconnect = True
642+
if should_reconnect is True:
643+
for original_id, payload in list(self._inflight.items()):
644+
self._received[original_id] = loop.create_future()
645+
to_send = json.loads(payload)
646+
await self._sending.put(to_send)
647+
logger.info("Timeout occurred. Reconnecting.")
648+
await self.connect(True)
649+
await self._handler(ws=ws)
650+
634651

635652
async def __aexit__(self, exc_type, exc_val, exc_tb):
636653
if not self.state != State.CONNECTING:
@@ -662,7 +679,7 @@ async def shutdown(self):
662679
pass
663680
self.ws = None
664681
self._initialized = False
665-
self._receiving_task = None
682+
self._send_recv_task = None
666683
self._is_closing = False
667684

668685
async def _recv(self, recd: bytes) -> None:
@@ -671,9 +688,12 @@ async def _recv(self, recd: bytes) -> None:
671688
response = json.loads(recd)
672689
self.last_received = await self.loop_time()
673690
if "id" in response:
691+
async with self._lock:
692+
self._inflight.pop(response["id"])
674693
self._received[response["id"]].set_result(response)
675694
self._in_use_ids.remove(response["id"])
676695
elif "params" in response:
696+
# TODO self._inflight won't work with subscriptions
677697
sub_id = response["params"]["subscription"]
678698
await self._received_subscriptions[sub_id].put(response)
679699
else:
@@ -682,7 +702,9 @@ async def _recv(self, recd: bytes) -> None:
682702
async def _start_receiving(self, ws: ClientConnection) -> Exception:
683703
try:
684704
while True:
685-
await self._recv(await ws.recv(decode=False))
705+
if self._inflight:
706+
recd = await asyncio.wait_for(ws.recv(decode=False), timeout=self.retry_timeout)
707+
await self._recv(recd)
686708
except Exception as e:
687709
if isinstance(e, ssl.SSLError):
688710
e = ConnectionClosed
@@ -696,13 +718,14 @@ async def _start_sending(self, ws) -> Exception:
696718
to_send = None
697719
try:
698720
while True:
699-
# TODO possibly when these are pulled from the Queue, they should also go into a dict or set, with the
700-
# TODO done_callback assigned to remove them when complete. This could allow easier resending in cases
701-
# TODO such as a timeout.
702-
to_send = await self._sending.get()
721+
to_send_ = await self._sending.get()
722+
send_id = to_send_["id"]
723+
to_send = json.dumps(to_send_)
724+
async with self._lock:
725+
self._inflight[send_id] = to_send
703726
if self._log_raw_websockets:
704727
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
705-
await ws.send(json.dumps(to_send))
728+
await ws.send(to_send)
706729
self.last_sent = await self.loop_time()
707730
except Exception as e:
708731
if to_send is not None:
@@ -824,6 +847,7 @@ def __init__(
824847
"write_limit": 2**16,
825848
},
826849
shutdown_timer=ws_shutdown_timer,
850+
retry_timeout=self.retry_timeout,
827851
)
828852
else:
829853
self.ws = AsyncMock(spec=Websocket)

0 commit comments

Comments
 (0)