11import asyncio
22import json
3+ import os
34import re
45import websockets
56from .common import get_base_url , get_credentials
@@ -16,11 +17,15 @@ def __init__(self, key_id=None, secret_key=None, base_url=None):
1617 self ._handlers = {}
1718 self ._handler_symbols = {}
1819 self ._base_url = base_url
20+ self ._streams = set ([])
1921 self ._ws = None
22+ self ._retry = int (os .environ .get ('APCA_RETRY_MAX' , 3 ))
23+ self ._retry_wait = int (os .environ .get ('APCA_RETRY_WAIT' , 3 ))
24+ self ._retries = 0
2025 self .polygon = None
2126 try :
2227 self .loop = asyncio .get_event_loop ()
23- except :
28+ except Exception :
2429 self .loop = asyncio .new_event_loop ()
2530 asyncio .set_event_loop (self .loop )
2631
@@ -43,12 +48,13 @@ async def _connect(self):
4348 ("Invalid Alpaca API credentials, Failed to authenticate: {}"
4449 .format (msg ))
4550 )
51+ else :
52+ self ._retries = 0
4653
4754 self ._ws = ws
4855 await self ._dispatch ('authorized' , msg )
4956
5057 asyncio .ensure_future (self ._consume_msg ())
51- return ws
5258
5359 async def _consume_msg (self ):
5460 ws = self ._ws
@@ -61,9 +67,9 @@ async def _consume_msg(self):
6167 stream = msg .get ('stream' )
6268 if stream is not None :
6369 await self ._dispatch (stream , msg )
64- finally :
65- await ws .close ()
66- self ._ws = None
70+ except Exception :
71+ await self .close ()
72+ asyncio . ensure_future ( self ._ensure_ws ())
6773
6874 async def _ensure_polygon (self ):
6975 if self .polygon is not None :
@@ -79,10 +85,22 @@ async def _ensure_polygon(self):
7985 async def _ensure_ws (self ):
8086 if self ._ws is not None :
8187 return
82- self ._ws = await self ._connect ()
88+
89+ while self ._retries <= self ._retry :
90+ try :
91+ await self ._connect ()
92+ if self ._streams :
93+ await self .subscribe (self ._streams )
94+ break
95+ except Exception :
96+ self ._ws = None
97+ self ._retries += 1
98+ await asyncio .sleep (self ._retry_wait * self ._retry )
99+ else :
100+ raise ConnectionError ("Max Retries Exceeded" )
83101
84102 async def subscribe (self , channels ):
85- '''Start subscribing channels.
103+ '''Start subscribing to channels.
86104 If the necessary connection isn't open yet, it opens now.
87105 '''
88106 ws_channels = []
@@ -94,6 +112,7 @@ async def subscribe(self, channels):
94112 ws_channels .append (c )
95113
96114 if len (ws_channels ) > 0 :
115+ self ._streams |= set (ws_channels )
97116 await self ._ensure_ws ()
98117 await self ._ws .send (json .dumps ({
99118 'action' : 'listen' ,
@@ -129,7 +148,7 @@ async def unsubscribe(self, channels):
129148 await self .polygon .unsubscribe (polygon_channels )
130149
131150 def run (self , initial_channels = []):
132- '''Run forever and block until exception is rasised .
151+ '''Run forever and block until exception is raised .
133152 initial_channels is the channels to start with.
134153 '''
135154 loop = self .loop
@@ -146,8 +165,10 @@ async def close(self):
146165 '''Close any of open connections'''
147166 if self ._ws is not None :
148167 await self ._ws .close ()
168+ self ._ws = None
149169 if self .polygon is not None :
150170 await self .polygon .close ()
171+ self .polygon = None
151172
152173 def _cast (self , channel , msg ):
153174 if channel == 'account_updates' :
0 commit comments