@@ -540,7 +540,7 @@ def __init__(
540
540
self .max_subscriptions = asyncio .Semaphore (max_subscriptions )
541
541
self .max_connections = max_connections
542
542
self .shutdown_timer = shutdown_timer
543
- self ._received = {}
543
+ self ._received : dict [ str , asyncio . Future ] = {}
544
544
self ._sending = asyncio .Queue ()
545
545
self ._receiving_task = None # TODO rename, as this now does send/recv
546
546
self ._attempts = 0
@@ -601,22 +601,23 @@ async def connect(self, force=False):
601
601
now = await self .loop_time ()
602
602
self .last_received = now
603
603
self .last_sent = now
604
- if self ._exit_task :
605
- self ._exit_task .cancel ()
606
- if self .state != State .CLOSING :
607
- if not self ._initialized or force :
608
- try :
609
- await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
610
- except asyncio .TimeoutError :
611
- pass
612
- self .ws = await asyncio .wait_for (
613
- connect (self .ws_url , ** self ._options ), timeout = 10.0
614
- )
615
- if self ._receiving_task is None or self ._receiving_task .done ():
616
- self ._receiving_task = asyncio .get_running_loop ().create_task (
617
- self ._handler (self .ws )
604
+ async with self ._lock :
605
+ if self ._exit_task :
606
+ self ._exit_task .cancel ()
607
+ if self .state not in (State .OPEN , State .CONNECTING ):
608
+ if not self ._initialized or force :
609
+ try :
610
+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
611
+ except asyncio .TimeoutError :
612
+ pass
613
+ self .ws = await asyncio .wait_for (
614
+ connect (self .ws_url , ** self ._options ), timeout = 10.0
618
615
)
619
- self ._initialized = True
616
+ if self ._receiving_task is None or self ._receiving_task .done ():
617
+ self ._receiving_task = asyncio .get_running_loop ().create_task (
618
+ self ._handler (self .ws )
619
+ )
620
+ self ._initialized = True
620
621
621
622
async def _handler (self , ws : ClientConnection ):
622
623
consumer_task = asyncio .create_task (self ._start_receiving (ws ))
@@ -669,10 +670,10 @@ async def _recv(self, recd) -> None:
669
670
response = json .loads (recd )
670
671
self .last_received = await self .loop_time ()
671
672
if "id" in response :
672
- self ._received [response ["id" ]] = response
673
+ self ._received [response ["id" ]]. set_result ( response )
673
674
self ._in_use_ids .remove (response ["id" ])
674
675
elif "params" in response :
675
- self ._received [response ["params" ]["subscription" ]] = response
676
+ self ._received [response ["params" ]["subscription" ]]. set_result ( response )
676
677
else :
677
678
raise KeyError (response )
678
679
except ssl .SSLError :
@@ -685,19 +686,26 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
685
686
async for recd in ws :
686
687
await self ._recv (recd )
687
688
except Exception as e :
688
- return e
689
+ for i in self ._received .keys ():
690
+ self ._received [i ].set_exception (e )
691
+ return
689
692
690
693
async def _start_sending (self , ws ) -> Exception :
694
+ to_send = None
691
695
try :
692
696
while True :
693
- logger .info ("699 Not Empty" )
694
697
to_send = await self ._sending .get ()
695
698
if self ._log_raw_websockets :
696
- raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
699
+ raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } } ")
697
700
await ws .send (json .dumps (to_send ))
698
701
self .last_sent = await self .loop_time ()
699
702
except Exception as e :
700
- return e
703
+ if to_send is not None :
704
+ self ._received [to_send ["id" ]].set_exception (e )
705
+ else :
706
+ for i in self ._received .keys ():
707
+ self ._received [i ].set_exception (e )
708
+ return
701
709
702
710
async def send (self , payload : dict ) -> str :
703
711
"""
@@ -715,8 +723,8 @@ async def send(self, payload: dict) -> str:
715
723
while original_id in self ._in_use_ids :
716
724
original_id = get_next_id ()
717
725
self ._in_use_ids .add (original_id )
726
+ self ._received [original_id ] = asyncio .get_running_loop ().create_future ()
718
727
to_send = {** payload , ** {"id" : original_id }}
719
- logger .info (f"Sending { to_send } " )
720
728
await self ._sending .put (to_send )
721
729
return original_id
722
730
@@ -730,11 +738,12 @@ async def retrieve(self, item_id: int) -> Optional[dict]:
730
738
Returns:
731
739
retrieved item
732
740
"""
733
- try :
734
- item = self . _received . pop ( item_id )
741
+ item : asyncio . Future = self . _received . get ( item_id )
742
+ if item . done ():
735
743
self .max_subscriptions .release ()
736
- return item
737
- except KeyError :
744
+ del self ._received [item_id ]
745
+ return item .result ()
746
+ else :
738
747
await asyncio .sleep (0.1 )
739
748
return None
740
749
@@ -2263,16 +2272,9 @@ async def _make_rpc_request(
2263
2272
subscription_added = False
2264
2273
2265
2274
async with self .ws as ws :
2266
- if len (payloads ) > 1 :
2267
- send_coroutines = await asyncio .gather (
2268
- * [ws .send (item ["payload" ]) for item in payloads ]
2269
- )
2270
- for item_id , item in zip (send_coroutines , payloads ):
2271
- request_manager .add_request (item_id , item ["id" ])
2272
- else :
2273
- item = payloads [0 ]
2274
- item_id = await ws .send (item ["payload" ])
2275
- request_manager .add_request (item_id , item ["id" ])
2275
+ for payload in payloads :
2276
+ item_id = await ws .send (payload )
2277
+ request_manager .add_request (item_id , payload ["id" ])
2276
2278
2277
2279
while True :
2278
2280
for item_id in list (request_manager .response_map .keys ()):
@@ -2311,17 +2313,17 @@ async def _make_rpc_request(
2311
2313
if request_manager .is_complete :
2312
2314
break
2313
2315
if (
2314
- (current_time := await self . ws .loop_time ()) - self . ws .last_received
2316
+ (current_time := await ws .loop_time ()) - ws .last_received
2315
2317
>= self .retry_timeout
2316
- and current_time - self . ws .last_sent >= self .retry_timeout
2318
+ and current_time - ws .last_sent >= self .retry_timeout
2317
2319
):
2318
2320
if attempt >= self .max_retries :
2319
2321
logger .error (
2320
2322
f"Timed out waiting for RPC requests { attempt } times. Exiting."
2321
2323
)
2322
2324
raise MaxRetriesExceeded ("Max retries reached." )
2323
2325
else :
2324
- self .ws .last_received = time . time ()
2326
+ self .ws .last_received = await ws . loop_time ()
2325
2327
await self .ws .connect (force = True )
2326
2328
logger .warning (
2327
2329
f"Timed out waiting for RPC requests. "
0 commit comments