3131 ss58_encode ,
3232 MultiAccountId ,
3333)
34- from websockets .asyncio .client import connect
34+ from websockets .asyncio .client import connect , ClientConnection
3535from websockets .exceptions import ConnectionClosed , WebSocketException
36+ from websockets .protocol import State
3637
3738from async_substrate_interface .errors import (
3839 SubstrateRequestException ,
7273 decode_query_map ,
7374)
7475
75- if TYPE_CHECKING :
76- from websockets .asyncio .client import ClientConnection
77-
7876ResultHandler = Callable [[dict , Any ], Awaitable [tuple [dict , bool ]]]
7977
8078logger = logging .getLogger ("async_substrate_interface" )
@@ -516,6 +514,7 @@ def __getitem__(self, item):
516514
517515
518516class Websocket :
517+ ws :
519518 def __init__ (
520519 self ,
521520 ws_url : str ,
@@ -538,22 +537,19 @@ def __init__(
538537 # TODO allow setting max concurrent connections and rpc subscriptions per connection
539538 # TODO reconnection logic
540539 self .ws_url = ws_url
541- self .ws : Optional [" ClientConnection" ] = None
540+ self .ws : Optional [ClientConnection ] = None
542541 self .max_subscriptions = asyncio .Semaphore (max_subscriptions )
543542 self .max_connections = max_connections
544543 self .shutdown_timer = shutdown_timer
545544 self ._received = {}
546- self ._in_use = 0
547- self ._receiving_task = None
545+ self ._sending = asyncio . Queue ()
546+ self ._receiving_task = None # TODO rename, as this now does send/recv
548547 self ._attempts = 0
549- self ._initialized = False
548+ self ._initialized = False # TODO remove
550549 self ._lock = asyncio .Lock ()
551550 self ._exit_task = None
552- self ._open_subscriptions = 0
553551 self ._options = options if options else {}
554552 self ._log_raw_websockets = _log_raw_websockets
555- self ._is_connecting = False
556- self ._is_closing = False
557553
558554 try :
559555 now = asyncio .get_running_loop ().time ()
@@ -570,9 +566,16 @@ def __init__(
570566 self .last_sent = now
571567 self ._in_use_ids = set ()
572568
569+ @property
570+ def state (self ):
571+ if self .ws is None :
572+ return State .CLOSED
573+ else :
574+ return self .ws .state
575+
573576 async def __aenter__ (self ):
574- self ._in_use += 1
575- await self .connect ()
577+ if self .state not in ( State . CONNECTING , State . OPEN ):
578+ await self .connect ()
576579 return self
577580
578581 @staticmethod
@@ -596,47 +599,47 @@ async def _cancel(self):
596599 )
597600
598601 async def connect (self , force = False ):
599- self ._is_connecting = True
600- try :
601- now = await self .loop_time ()
602- self .last_received = now
603- self .last_sent = now
604- if self ._exit_task :
605- self ._exit_task .cancel ()
606- if not self ._is_closing :
607- if not self ._initialized or force :
608- try :
609- await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
610- except asyncio .TimeoutError :
611- pass
612-
613- self .ws = await asyncio .wait_for (
614- connect (self .ws_url , ** self ._options ), timeout = 10.0
602+ now = await self .loop_time ()
603+ self .last_received = now
604+ self .last_sent = now
605+ if self ._exit_task :
606+ self ._exit_task .cancel ()
607+ if self .state != State .CLOSING :
608+ if not self ._initialized or force :
609+ try :
610+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
611+ except asyncio .TimeoutError :
612+ pass
613+ self .ws = await asyncio .wait_for (
614+ connect (self .ws_url , ** self ._options ), timeout = 10.0
615+ )
616+ if self ._receiving_task is None or self ._receiving_task .done ():
617+ self ._receiving_task = asyncio .get_running_loop ().create_task (
618+ self ._handler (self .ws )
615619 )
616- if self ._receiving_task is None or self ._receiving_task .done ():
617- self ._receiving_task = asyncio .get_running_loop ().create_task (
618- self ._start_receiving ()
619- )
620- self ._initialized = True
621- finally :
622- self ._is_connecting = False
620+ self ._initialized = True
621+
622+ async def _handler (self , ws : ClientConnection ):
623+ consumer_task = asyncio .create_task (self ._start_receiving (ws ))
624+ producer_task = asyncio .create_task (self ._start_sending (ws ))
625+ # TODO should attach futures and add exceptions to them
626+ done , pending = await asyncio .wait (
627+ [consumer_task , producer_task ],
628+ return_when = asyncio .FIRST_COMPLETED ,
629+ )
630+ for task in pending :
631+ task .cancel ()
623632
624633 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
625- self ._is_closing = True
626- try :
627- if not self ._is_connecting :
628- self ._in_use -= 1
629- if self ._exit_task is not None :
630- self ._exit_task .cancel ()
631- try :
632- await self ._exit_task
633- except asyncio .CancelledError :
634- pass
635- if self ._in_use == 0 and self .ws is not None :
636- self ._open_subscriptions = 0
637- self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
638- finally :
639- self ._is_closing = False
634+ if not self .state != State .CONNECTING :
635+ if self ._exit_task is not None :
636+ self ._exit_task .cancel ()
637+ try :
638+ await self ._exit_task
639+ except asyncio .CancelledError :
640+ pass
641+ if self .ws is not None :
642+ self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
640643
641644 async def _exit_with_timer (self ):
642645 """
@@ -660,12 +663,10 @@ async def shutdown(self):
660663 self ._receiving_task = None
661664 self ._is_closing = False
662665
663- async def _recv (self ) -> None :
666+ async def _recv (self , recd ) -> None :
664667 try :
665- # TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
666- recd = await self .ws .recv (decode = False )
667668 if self ._log_raw_websockets :
668- raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd . decode () } " )
669+ raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd } " )
669670 response = json .loads (recd )
670671 self .last_received = await self .loop_time ()
671672 if "id" in response :
@@ -680,14 +681,24 @@ async def _recv(self) -> None:
680681 except (ConnectionClosed , KeyError ):
681682 raise
682683
683- async def _start_receiving (self ):
684+ async def _start_receiving (self , ws : ClientConnection ) -> Exception :
685+ try :
686+ async for recd in ws :
687+ await self ._recv (recd )
688+ except Exception as e :
689+ return e
690+
691+ async def _start_sending (self , ws ) -> Exception :
684692 try :
685693 while True :
686- await self ._recv ()
687- except asyncio .CancelledError :
688- pass
689- except ConnectionClosed :
690- await self .connect (force = True )
694+ logger .info ("699 Not Empty" )
695+ to_send = await self ._sending .get ()
696+ if self ._log_raw_websockets :
697+ raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
698+ await ws .send (json .dumps (to_send ))
699+ self .last_sent = await self .loop_time ()
700+ except Exception as e :
701+ return e
691702
692703 async def send (self , payload : dict ) -> str :
693704 """
@@ -699,22 +710,16 @@ async def send(self, payload: dict) -> str:
699710 Returns:
700711 id: the internal ID of the request (incremented int)
701712 """
702- original_id = get_next_id ()
703- while original_id in self ._in_use_ids :
704- original_id = get_next_id ()
705- self ._in_use_ids .add (original_id )
706- # self._open_subscriptions += 1
707713 await self .max_subscriptions .acquire ()
708- try :
709- to_send = {** payload , ** {"id" : original_id }}
710- if self ._log_raw_websockets :
711- raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
712- await self .ws .send (json .dumps (to_send ))
713- self .last_sent = await self .loop_time ()
714- return original_id
715- except (ConnectionClosed , ssl .SSLError , EOFError ):
716- await self .connect (force = True )
717- return await self .send (payload )
714+ async with self ._lock :
715+ original_id = get_next_id ()
716+ while original_id in self ._in_use_ids :
717+ original_id = get_next_id ()
718+ self ._in_use_ids .add (original_id )
719+ to_send = {** payload , ** {"id" : original_id }}
720+ logger .info (f"Sending { to_send } " )
721+ await self ._sending .put (to_send )
722+ return original_id
718723
719724 async def retrieve (self , item_id : int ) -> Optional [dict ]:
720725 """
@@ -827,6 +832,7 @@ async def initialize(self):
827832 """
828833 self ._initializing = True
829834 if not self .initialized :
835+ await self .ws .connect ()
830836 if not self ._chain :
831837 chain = await self .rpc_request ("system_chain" , [])
832838 self ._chain = chain .get ("result" )
@@ -845,7 +851,7 @@ async def initialize(self):
845851 self ._initializing = False
846852
847853 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
848- pass
854+ await self . ws . shutdown ()
849855
850856 @property
851857 def metadata (self ):
0 commit comments