@@ -18,93 +18,151 @@ def __init__(self, key_id=None):
1818 ).rstrip ('/' )
1919 self ._handlers = {}
2020 self ._handler_symbols = {}
21+ self ._streams = set ([])
2122 self ._ws = None
2223 self ._retry = int (os .environ .get ('APCA_RETRY_MAX' , 3 ))
2324 self ._retry_wait = int (os .environ .get ('APCA_RETRY_WAIT' , 3 ))
2425 self ._retries = 0
2526
2627 async def connect (self ):
27- await self ._dispatch ('status' ,
28- {'ev' : 'status' ,
28+ await self ._dispatch ({'ev' : 'status' ,
2929 'status' : 'connecting' ,
3030 'message' : 'Connecting to Polygon' })
31- ws = await websockets .connect (self ._endpoint )
31+ self ._ws = await websockets .connect (self ._endpoint )
32+ self ._stream = self ._recv ()
33+
34+ msg = await self ._next ()
35+ if msg .get ('status' ) != 'connected' :
36+ raise ValueError (
37+ ("Invalid response on Polygon websocket connection: {}"
38+ .format (msg ))
39+ )
40+ await self ._dispatch (msg )
41+ if await self .authenticate ():
42+ asyncio .ensure_future (self ._consume_msg ())
43+ else :
44+ await self .close ()
45+
46+ async def authenticate (self ):
47+ ws = self ._ws
48+ if not ws :
49+ return False
3250
3351 await ws .send (json .dumps ({
3452 'action' : 'auth' ,
3553 'params' : self ._key_id
3654 }))
37- r = await ws .recv ()
38- if isinstance (r , bytes ):
39- r = r .decode ('utf-8' )
40- msg = json .loads (r )
41- if msg [0 ].get ('status' ) != 'connected' :
42- raise ValueError (
43- ("Invalid Polygon credentials, Failed to authenticate: {}"
44- .format (msg ))
45- )
55+ data = await self ._next ()
56+ stream = data .get ('ev' )
57+ msg = data .get ('message' )
58+ status = data .get ('status' )
59+ if (stream == 'status'
60+ and msg == 'authenticated'
61+ and status == 'success' ):
62+ # reset retries only after we successfully authenticated
63+ self ._retries = 0
64+ await self ._dispatch (data )
65+ return True
66+ else :
67+ raise ValueError ('Invalid Polygon credentials, '
68+ f'Failed to authenticate: { data } ' )
4669
47- self ._retries = 0
48- self ._ws = ws
49- await self ._dispatch ('authorized' , msg [0 ])
70+ async def _next (self ):
71+ '''Returns the next message available
72+ '''
73+ return await self ._stream .__anext__ ()
5074
51- asyncio .ensure_future (self ._consume_msg ())
75+ async def _recv (self ):
76+ '''Function used to recieve and parse all messages from websocket stream.
5277
53- async def _consume_msg (self ):
54- ws = self ._ws
55- if not ws :
56- return
78+ This generator yields one message per each call.
79+ '''
5780 try :
5881 while True :
59- r = await ws .recv ()
82+ r = await self . _ws .recv ()
6083 if isinstance (r , bytes ):
6184 r = r .decode ('utf-8' )
6285 msg = json .loads (r )
6386 for update in msg :
64- stream = update .get ('ev' )
65- if stream is not None :
66- await self ._dispatch (stream , update )
67- except websockets .exceptions .ConnectionClosedError :
68- await self ._dispatch ('status' ,
69- {'ev' : 'status' ,
87+ yield update
88+ except websockets .exceptions .ConnectionClosedError as e :
89+ await self ._dispatch ({'ev' : 'status' ,
7090 'status' : 'disconnected' ,
7191 'message' :
72- 'Polygon Disconnected Unexpectedly' })
73- finally :
74- if self ._ws is not None :
75- await self ._ws .close ()
76- self ._ws = None
92+ f'Polygon Disconnected Unexpectedly ({ e } )' })
93+ await self .close ()
7794 asyncio .ensure_future (self ._ensure_ws ())
7895
96+ async def _consume_msg (self ):
97+ async for data in self ._stream :
98+ stream = data .get ('ev' )
99+ if stream :
100+ await self ._dispatch (data )
101+ elif data .get ('status' ) == 'disconnected' :
102+ # Polygon returns this on an empty 'ev' id..
103+ data ['ev' ] = 'status'
104+ await self ._dispatch (data )
105+ raise ConnectionResetError (
106+ 'Polygon terminated connection: '
107+ f'({ data .get ("message" )} )' )
108+
79109 async def _ensure_ws (self ):
80110 if self ._ws is not None :
81111 return
82- try :
83- await self .connect ()
84- except Exception :
85- self ._ws = None
86- self ._retries += 1
87- time .sleep (self ._retry_wait )
88- if self ._retries <= self ._retry :
89- asyncio .ensure_future (self ._ensure_ws ())
90- else :
91- raise ConnectionError ("Max Retries Exceeded" )
112+
113+ while self ._retries <= self ._retry :
114+ try :
115+ await self .connect ()
116+ if self ._streams :
117+ await self .subscribe (self ._streams )
118+
119+ break
120+ except (ConnectionRefusedError , ConnectionError ) as e :
121+ await self ._dispatch ({'ev' : 'status' ,
122+ 'status' : 'connect failed' ,
123+ 'message' :
124+ f'Connection Failed ({ e } )' })
125+ self ._ws = None
126+ self ._retries += 1
127+ time .sleep (self ._retry_wait * self ._retry )
128+ else :
129+ raise ConnectionError ("Max Retries Exceeded" )
92130
93131 async def subscribe (self , channels ):
94- '''Start subscribing channels.
132+ '''Subscribe to channels.
133+ Note: This is cumulative, meaning you can add channels at runtime,
134+ and you do not need to specify all the channels.
135+
136+ To remove channels see unsubscribe().
137+
95138 If the necessary connection isn't open yet, it opens now.
96139 '''
97140 if len (channels ) > 0 :
98141 await self ._ensure_ws ()
99142 # Join channel list to string
100143 streams = ',' .join (channels )
144+ self ._streams |= set (channels )
101145 await self ._ws .send (json .dumps ({
102146 'action' : 'subscribe' ,
103147 'params' : streams
104148 }))
105149
150+ async def unsubscribe (self , channels ):
151+ '''Unsubscribe from channels
152+ '''
153+ if not self ._ws :
154+ return
155+ if len (channels ) > 0 :
156+ # Join channel list to string
157+ streams = ',' .join (channels )
158+ self ._streams -= set (channels )
159+ await self ._ws .send (json .dumps ({
160+ 'action' : 'unsubscribe' ,
161+ 'params' : streams
162+ }))
163+
106164 def run (self , initial_channels = []):
107- '''Run forever and block until exception is rasised .
165+ '''Run forever and block until exception is raised .
108166 initial_channels is the channels to start with.
109167 '''
110168 loop = asyncio .get_event_loop ()
@@ -118,6 +176,7 @@ async def close(self):
118176 '''Close any open connections'''
119177 if self ._ws is not None :
120178 await self ._ws .close ()
179+ self ._ws = None
121180
122181 def _cast (self , subject , data ):
123182 if subject == 'T' :
@@ -166,7 +225,8 @@ def _cast(self, subject, data):
166225 ent = Entity (data )
167226 return ent
168227
169- async def _dispatch (self , channel , msg ):
228+ async def _dispatch (self , msg ):
229+ channel = msg .get ('ev' )
170230 for pat , handler in self ._handlers .items ():
171231 if pat .match (channel ):
172232 handled_symbols = self ._handler_symbols .get (handler )
0 commit comments