Skip to content

Commit 12676fa

Browse files
committed
WIP check-in
1 parent c325f25 commit 12676fa

File tree

1 file changed

+83
-77
lines changed

1 file changed

+83
-77
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 83 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@
3131
ss58_encode,
3232
MultiAccountId,
3333
)
34-
from websockets.asyncio.client import connect
34+
from websockets.asyncio.client import connect, ClientConnection
3535
from websockets.exceptions import ConnectionClosed, WebSocketException
36+
from websockets.protocol import State
3637

3738
from async_substrate_interface.errors import (
3839
SubstrateRequestException,
@@ -72,9 +73,6 @@
7273
decode_query_map,
7374
)
7475

75-
if TYPE_CHECKING:
76-
from websockets.asyncio.client import ClientConnection
77-
7876
ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]]
7977

8078
logger = logging.getLogger("async_substrate_interface")
@@ -516,6 +514,7 @@ def __getitem__(self, item):
516514

517515

518516
class Websocket:
517+
ws:
519518
def __init__(
520519
self,
521520
ws_url: str,
@@ -538,22 +537,19 @@ def __init__(
538537
# TODO allow setting max concurrent connections and rpc subscriptions per connection
539538
# TODO reconnection logic
540539
self.ws_url = ws_url
541-
self.ws: Optional["ClientConnection"] = None
540+
self.ws: Optional[ClientConnection] = None
542541
self.max_subscriptions = asyncio.Semaphore(max_subscriptions)
543542
self.max_connections = max_connections
544543
self.shutdown_timer = shutdown_timer
545544
self._received = {}
546-
self._in_use = 0
547-
self._receiving_task = None
545+
self._sending = asyncio.Queue()
546+
self._receiving_task = None # TODO rename, as this now does send/recv
548547
self._attempts = 0
549-
self._initialized = False
548+
self._initialized = False # TODO remove
550549
self._lock = asyncio.Lock()
551550
self._exit_task = None
552-
self._open_subscriptions = 0
553551
self._options = options if options else {}
554552
self._log_raw_websockets = _log_raw_websockets
555-
self._is_connecting = False
556-
self._is_closing = False
557553

558554
try:
559555
now = asyncio.get_running_loop().time()
@@ -570,9 +566,16 @@ def __init__(
570566
self.last_sent = now
571567
self._in_use_ids = set()
572568

569+
@property
570+
def state(self):
571+
if self.ws is None:
572+
return State.CLOSED
573+
else:
574+
return self.ws.state
575+
573576
async def __aenter__(self):
574-
self._in_use += 1
575-
await self.connect()
577+
if self.state not in (State.CONNECTING, State.OPEN):
578+
await self.connect()
576579
return self
577580

578581
@staticmethod
@@ -596,47 +599,47 @@ async def _cancel(self):
596599
)
597600

598601
async def connect(self, force=False):
599-
self._is_connecting = True
600-
try:
601-
now = await self.loop_time()
602-
self.last_received = now
603-
self.last_sent = now
604-
if self._exit_task:
605-
self._exit_task.cancel()
606-
if not self._is_closing:
607-
if not self._initialized or force:
608-
try:
609-
await asyncio.wait_for(self._cancel(), timeout=10.0)
610-
except asyncio.TimeoutError:
611-
pass
612-
613-
self.ws = await asyncio.wait_for(
614-
connect(self.ws_url, **self._options), timeout=10.0
602+
now = await self.loop_time()
603+
self.last_received = now
604+
self.last_sent = now
605+
if self._exit_task:
606+
self._exit_task.cancel()
607+
if self.state != State.CLOSING:
608+
if not self._initialized or force:
609+
try:
610+
await asyncio.wait_for(self._cancel(), timeout=10.0)
611+
except asyncio.TimeoutError:
612+
pass
613+
self.ws = await asyncio.wait_for(
614+
connect(self.ws_url, **self._options), timeout=10.0
615+
)
616+
if self._receiving_task is None or self._receiving_task.done():
617+
self._receiving_task = asyncio.get_running_loop().create_task(
618+
self._handler(self.ws)
615619
)
616-
if self._receiving_task is None or self._receiving_task.done():
617-
self._receiving_task = asyncio.get_running_loop().create_task(
618-
self._start_receiving()
619-
)
620-
self._initialized = True
621-
finally:
622-
self._is_connecting = False
620+
self._initialized = True
621+
622+
async def _handler(self, ws: ClientConnection):
623+
consumer_task = asyncio.create_task(self._start_receiving(ws))
624+
producer_task = asyncio.create_task(self._start_sending(ws))
625+
# TODO should attach futures and add exceptions to them
626+
done, pending = await asyncio.wait(
627+
[consumer_task, producer_task],
628+
return_when=asyncio.FIRST_COMPLETED,
629+
)
630+
for task in pending:
631+
task.cancel()
623632

624633
async def __aexit__(self, exc_type, exc_val, exc_tb):
625-
self._is_closing = True
626-
try:
627-
if not self._is_connecting:
628-
self._in_use -= 1
629-
if self._exit_task is not None:
630-
self._exit_task.cancel()
631-
try:
632-
await self._exit_task
633-
except asyncio.CancelledError:
634-
pass
635-
if self._in_use == 0 and self.ws is not None:
636-
self._open_subscriptions = 0
637-
self._exit_task = asyncio.create_task(self._exit_with_timer())
638-
finally:
639-
self._is_closing = False
634+
if not self.state != State.CONNECTING:
635+
if self._exit_task is not None:
636+
self._exit_task.cancel()
637+
try:
638+
await self._exit_task
639+
except asyncio.CancelledError:
640+
pass
641+
if self.ws is not None:
642+
self._exit_task = asyncio.create_task(self._exit_with_timer())
640643

641644
async def _exit_with_timer(self):
642645
"""
@@ -660,12 +663,10 @@ async def shutdown(self):
660663
self._receiving_task = None
661664
self._is_closing = False
662665

663-
async def _recv(self) -> None:
666+
async def _recv(self, recd) -> None:
664667
try:
665-
# TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
666-
recd = await self.ws.recv(decode=False)
667668
if self._log_raw_websockets:
668-
raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd.decode()}")
669+
raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd}")
669670
response = json.loads(recd)
670671
self.last_received = await self.loop_time()
671672
if "id" in response:
@@ -680,14 +681,24 @@ async def _recv(self) -> None:
680681
except (ConnectionClosed, KeyError):
681682
raise
682683

683-
async def _start_receiving(self):
684+
async def _start_receiving(self, ws: ClientConnection) -> Exception:
685+
try:
686+
async for recd in ws:
687+
await self._recv(recd)
688+
except Exception as e:
689+
return e
690+
691+
async def _start_sending(self, ws) -> Exception:
684692
try:
685693
while True:
686-
await self._recv()
687-
except asyncio.CancelledError:
688-
pass
689-
except ConnectionClosed:
690-
await self.connect(force=True)
694+
logger.info("699 Not Empty")
695+
to_send = await self._sending.get()
696+
if self._log_raw_websockets:
697+
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
698+
await ws.send(json.dumps(to_send))
699+
self.last_sent = await self.loop_time()
700+
except Exception as e:
701+
return e
691702

692703
async def send(self, payload: dict) -> str:
693704
"""
@@ -699,22 +710,16 @@ async def send(self, payload: dict) -> str:
699710
Returns:
700711
id: the internal ID of the request (incremented int)
701712
"""
702-
original_id = get_next_id()
703-
while original_id in self._in_use_ids:
704-
original_id = get_next_id()
705-
self._in_use_ids.add(original_id)
706-
# self._open_subscriptions += 1
707713
await self.max_subscriptions.acquire()
708-
try:
709-
to_send = {**payload, **{"id": original_id}}
710-
if self._log_raw_websockets:
711-
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
712-
await self.ws.send(json.dumps(to_send))
713-
self.last_sent = await self.loop_time()
714-
return original_id
715-
except (ConnectionClosed, ssl.SSLError, EOFError):
716-
await self.connect(force=True)
717-
return await self.send(payload)
714+
async with self._lock:
715+
original_id = get_next_id()
716+
while original_id in self._in_use_ids:
717+
original_id = get_next_id()
718+
self._in_use_ids.add(original_id)
719+
to_send = {**payload, **{"id": original_id}}
720+
logger.info(f"Sending {to_send}")
721+
await self._sending.put(to_send)
722+
return original_id
718723

719724
async def retrieve(self, item_id: int) -> Optional[dict]:
720725
"""
@@ -827,6 +832,7 @@ async def initialize(self):
827832
"""
828833
self._initializing = True
829834
if not self.initialized:
835+
await self.ws.connect()
830836
if not self._chain:
831837
chain = await self.rpc_request("system_chain", [])
832838
self._chain = chain.get("result")
@@ -845,7 +851,7 @@ async def initialize(self):
845851
self._initializing = False
846852

847853
async def __aexit__(self, exc_type, exc_val, exc_tb):
848-
pass
854+
await self.ws.shutdown()
849855

850856
@property
851857
def metadata(self):

0 commit comments

Comments
 (0)