@@ -524,6 +524,7 @@ def __init__(
524
524
shutdown_timer = 5 ,
525
525
options : Optional [dict ] = None ,
526
526
_log_raw_websockets : bool = False ,
527
+ retry_timeout : float = 60.0
527
528
):
528
529
"""
529
530
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -542,10 +543,12 @@ def __init__(
542
543
self .max_subscriptions = asyncio .Semaphore (max_subscriptions )
543
544
self .max_connections = max_connections
544
545
self .shutdown_timer = shutdown_timer
546
+ self .retry_timeout = retry_timeout
545
547
self ._received : dict [str , asyncio .Future ] = {}
546
548
self ._received_subscriptions : dict [str , asyncio .Queue ] = {}
547
549
self ._sending = asyncio .Queue ()
548
- self ._receiving_task = None # TODO rename, as this now does send/recv
550
+ self ._send_recv_task = None
551
+ self ._inflight : dict [str , str ] = {}
549
552
self ._attempts = 0
550
553
self ._initialized = False # TODO remove
551
554
self ._lock = asyncio .Lock ()
@@ -586,8 +589,8 @@ async def loop_time() -> float:
586
589
587
590
async def _cancel (self ):
588
591
try :
589
- self ._receiving_task .cancel ()
590
- await self ._receiving_task
592
+ self ._send_recv_task .cancel ()
593
+ await self ._send_recv_task
591
594
await self .ws .close ()
592
595
except (
593
596
AttributeError ,
@@ -601,13 +604,14 @@ async def _cancel(self):
601
604
)
602
605
603
606
async def connect (self , force = False ):
607
+ # TODO after connecting, move from _inflight to the queue
604
608
now = await self .loop_time ()
605
609
self .last_received = now
606
610
self .last_sent = now
607
611
async with self ._lock :
608
612
if self ._exit_task :
609
613
self ._exit_task .cancel ()
610
- if self .state not in (State .OPEN , State .CONNECTING ):
614
+ if self .state not in (State .OPEN , State .CONNECTING ) or force :
611
615
if not self ._initialized or force :
612
616
try :
613
617
await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
@@ -616,21 +620,34 @@ async def connect(self, force=False):
616
620
self .ws = await asyncio .wait_for (
617
621
connect (self .ws_url , ** self ._options ), timeout = 10.0
618
622
)
619
- if self ._receiving_task is None or self ._receiving_task .done ():
620
- self ._receiving_task = asyncio .get_running_loop ().create_task (
623
+ if self ._send_recv_task is None or self ._send_recv_task .done ():
624
+ self ._send_recv_task = asyncio .get_running_loop ().create_task (
621
625
self ._handler (self .ws )
622
626
)
623
627
self ._initialized = True
624
628
625
- async def _handler (self , ws : ClientConnection ):
626
- consumer_task = asyncio .create_task (self ._start_receiving (ws ))
627
- producer_task = asyncio .create_task (self ._start_sending (ws ))
629
+ async def _handler (self , ws : ClientConnection ) -> None :
630
+ recv_task = asyncio .create_task (self ._start_receiving (ws ))
631
+ send_task = asyncio .create_task (self ._start_sending (ws ))
628
632
done , pending = await asyncio .wait (
629
- [consumer_task , producer_task ],
633
+ [recv_task , send_task ],
630
634
return_when = asyncio .FIRST_COMPLETED ,
631
635
)
636
+ loop = asyncio .get_running_loop ()
637
+ should_reconnect = False
632
638
for task in pending :
633
639
task .cancel ()
640
+ if isinstance (task .exception (), asyncio .TimeoutError ):
641
+ should_reconnect = True
642
+ if should_reconnect is True :
643
+ for original_id , payload in list (self ._inflight .items ()):
644
+ self ._received [original_id ] = loop .create_future ()
645
+ to_send = json .loads (payload )
646
+ await self ._sending .put (to_send )
647
+ logger .info ("Timeout occurred. Reconnecting." )
648
+ await self .connect (True )
649
+ await self ._handler (ws = ws )
650
+
634
651
635
652
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
636
653
if not self .state != State .CONNECTING :
@@ -662,7 +679,7 @@ async def shutdown(self):
662
679
pass
663
680
self .ws = None
664
681
self ._initialized = False
665
- self ._receiving_task = None
682
+ self ._send_recv_task = None
666
683
self ._is_closing = False
667
684
668
685
async def _recv (self , recd : bytes ) -> None :
@@ -671,9 +688,12 @@ async def _recv(self, recd: bytes) -> None:
671
688
response = json .loads (recd )
672
689
self .last_received = await self .loop_time ()
673
690
if "id" in response :
691
+ async with self ._lock :
692
+ self ._inflight .pop (response ["id" ])
674
693
self ._received [response ["id" ]].set_result (response )
675
694
self ._in_use_ids .remove (response ["id" ])
676
695
elif "params" in response :
696
+ # TODO self._inflight won't work with subscriptions
677
697
sub_id = response ["params" ]["subscription" ]
678
698
await self ._received_subscriptions [sub_id ].put (response )
679
699
else :
@@ -682,7 +702,9 @@ async def _recv(self, recd: bytes) -> None:
682
702
async def _start_receiving (self , ws : ClientConnection ) -> Exception :
683
703
try :
684
704
while True :
685
- await self ._recv (await ws .recv (decode = False ))
705
+ if self ._inflight :
706
+ recd = await asyncio .wait_for (ws .recv (decode = False ), timeout = self .retry_timeout )
707
+ await self ._recv (recd )
686
708
except Exception as e :
687
709
if isinstance (e , ssl .SSLError ):
688
710
e = ConnectionClosed
@@ -696,13 +718,14 @@ async def _start_sending(self, ws) -> Exception:
696
718
to_send = None
697
719
try :
698
720
while True :
699
- # TODO possibly when these are pulled from the Queue, they should also go into a dict or set, with the
700
- # TODO done_callback assigned to remove them when complete. This could allow easier resending in cases
701
- # TODO such as a timeout.
702
- to_send = await self ._sending .get ()
721
+ to_send_ = await self ._sending .get ()
722
+ send_id = to_send_ ["id" ]
723
+ to_send = json .dumps (to_send_ )
724
+ async with self ._lock :
725
+ self ._inflight [send_id ] = to_send
703
726
if self ._log_raw_websockets :
704
727
raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
705
- await ws .send (json . dumps ( to_send ) )
728
+ await ws .send (to_send )
706
729
self .last_sent = await self .loop_time ()
707
730
except Exception as e :
708
731
if to_send is not None :
@@ -824,6 +847,7 @@ def __init__(
824
847
"write_limit" : 2 ** 16 ,
825
848
},
826
849
shutdown_timer = ws_shutdown_timer ,
850
+ retry_timeout = self .retry_timeout ,
827
851
)
828
852
else :
829
853
self .ws = AsyncMock (spec = Websocket )
0 commit comments