@@ -540,7 +540,7 @@ def __init__(
540540 self .max_subscriptions = asyncio .Semaphore (max_subscriptions )
541541 self .max_connections = max_connections
542542 self .shutdown_timer = shutdown_timer
543- self ._received = {}
543+ self ._received : dict [ str , asyncio . Future ] = {}
544544 self ._sending = asyncio .Queue ()
545545 self ._receiving_task = None # TODO rename, as this now does send/recv
546546 self ._attempts = 0
@@ -601,22 +601,23 @@ async def connect(self, force=False):
601601 now = await self .loop_time ()
602602 self .last_received = now
603603 self .last_sent = now
604- if self ._exit_task :
605- self ._exit_task .cancel ()
606- if self .state != State .CLOSING :
607- if not self ._initialized or force :
608- try :
609- await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
610- except asyncio .TimeoutError :
611- pass
612- self .ws = await asyncio .wait_for (
613- connect (self .ws_url , ** self ._options ), timeout = 10.0
614- )
615- if self ._receiving_task is None or self ._receiving_task .done ():
616- self ._receiving_task = asyncio .get_running_loop ().create_task (
617- self ._handler (self .ws )
604+ async with self ._lock :
605+ if self ._exit_task :
606+ self ._exit_task .cancel ()
607+ if self .state not in (State .OPEN , State .CONNECTING ):
608+ if not self ._initialized or force :
609+ try :
610+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
611+ except asyncio .TimeoutError :
612+ pass
613+ self .ws = await asyncio .wait_for (
614+ connect (self .ws_url , ** self ._options ), timeout = 10.0
618615 )
619- self ._initialized = True
616+ if self ._receiving_task is None or self ._receiving_task .done ():
617+ self ._receiving_task = asyncio .get_running_loop ().create_task (
618+ self ._handler (self .ws )
619+ )
620+ self ._initialized = True
620621
621622 async def _handler (self , ws : ClientConnection ):
622623 consumer_task = asyncio .create_task (self ._start_receiving (ws ))
@@ -669,10 +670,10 @@ async def _recv(self, recd) -> None:
669670 response = json .loads (recd )
670671 self .last_received = await self .loop_time ()
671672 if "id" in response :
672- self ._received [response ["id" ]] = response
673+ self ._received [response ["id" ]]. set_result ( response )
673674 self ._in_use_ids .remove (response ["id" ])
674675 elif "params" in response :
675- self ._received [response ["params" ]["subscription" ]] = response
676+ self ._received [response ["params" ]["subscription" ]]. set_result ( response )
676677 else :
677678 raise KeyError (response )
678679 except ssl .SSLError :
@@ -685,19 +686,26 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
685686 async for recd in ws :
686687 await self ._recv (recd )
687688 except Exception as e :
688- return e
689+ for i in self ._received .keys ():
690+ self ._received [i ].set_exception (e )
691+ return
689692
690693 async def _start_sending (self , ws ) -> Exception :
694+ to_send = None
691695 try :
692696 while True :
693- logger .info ("699 Not Empty" )
694697 to_send = await self ._sending .get ()
695698 if self ._log_raw_websockets :
696- raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
699+ raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } } ")
697700 await ws .send (json .dumps (to_send ))
698701 self .last_sent = await self .loop_time ()
699702 except Exception as e :
700- return e
703+ if to_send is not None :
704+ self ._received [to_send ["id" ]].set_exception (e )
705+ else :
706+ for i in self ._received .keys ():
707+ self ._received [i ].set_exception (e )
708+ return
701709
702710 async def send (self , payload : dict ) -> str :
703711 """
@@ -715,8 +723,8 @@ async def send(self, payload: dict) -> str:
715723 while original_id in self ._in_use_ids :
716724 original_id = get_next_id ()
717725 self ._in_use_ids .add (original_id )
726+ self ._received [original_id ] = asyncio .get_running_loop ().create_future ()
718727 to_send = {** payload , ** {"id" : original_id }}
719- logger .info (f"Sending { to_send } " )
720728 await self ._sending .put (to_send )
721729 return original_id
722730
@@ -730,11 +738,12 @@ async def retrieve(self, item_id: int) -> Optional[dict]:
730738 Returns:
731739 retrieved item
732740 """
733- try :
734- item = self . _received . pop ( item_id )
741+ item : asyncio . Future = self . _received . get ( item_id )
742+ if item . done ():
735743 self .max_subscriptions .release ()
736- return item
737- except KeyError :
744+ del self ._received [item_id ]
745+ return item .result ()
746+ else :
738747 await asyncio .sleep (0.1 )
739748 return None
740749
@@ -2263,16 +2272,9 @@ async def _make_rpc_request(
22632272 subscription_added = False
22642273
22652274 async with self .ws as ws :
2266- if len (payloads ) > 1 :
2267- send_coroutines = await asyncio .gather (
2268- * [ws .send (item ["payload" ]) for item in payloads ]
2269- )
2270- for item_id , item in zip (send_coroutines , payloads ):
2271- request_manager .add_request (item_id , item ["id" ])
2272- else :
2273- item = payloads [0 ]
2274- item_id = await ws .send (item ["payload" ])
2275- request_manager .add_request (item_id , item ["id" ])
2275+ for payload in payloads :
2276+ item_id = await ws .send (payload )
2277+ request_manager .add_request (item_id , payload ["id" ])
22762278
22772279 while True :
22782280 for item_id in list (request_manager .response_map .keys ()):
@@ -2311,17 +2313,17 @@ async def _make_rpc_request(
23112313 if request_manager .is_complete :
23122314 break
23132315 if (
2314- (current_time := await self . ws .loop_time ()) - self . ws .last_received
2316+ (current_time := await ws .loop_time ()) - ws .last_received
23152317 >= self .retry_timeout
2316- and current_time - self . ws .last_sent >= self .retry_timeout
2318+ and current_time - ws .last_sent >= self .retry_timeout
23172319 ):
23182320 if attempt >= self .max_retries :
23192321 logger .error (
23202322 f"Timed out waiting for RPC requests { attempt } times. Exiting."
23212323 )
23222324 raise MaxRetriesExceeded ("Max retries reached." )
23232325 else :
2324- self .ws .last_received = time . time ()
2326+ self .ws .last_received = await ws . loop_time ()
23252327 await self .ws .connect (force = True )
23262328 logger .warning (
23272329 f"Timed out waiting for RPC requests. "
0 commit comments