@@ -17,93 +17,151 @@ def __init__(self, key_id=None):
1717 'wss://alpaca.socket.polygon.io/stocks'
1818 ).rstrip ('/' )
1919 self ._handlers = {}
20+ self ._streams = set ([])
2021 self ._ws = None
2122 self ._retry = int (os .environ .get ('APCA_RETRY_MAX' , 3 ))
2223 self ._retry_wait = int (os .environ .get ('APCA_RETRY_WAIT' , 3 ))
2324 self ._retries = 0
2425
2526 async def connect (self ):
26- await self ._dispatch ('status' ,
27- {'ev' : 'status' ,
27+ await self ._dispatch ({'ev' : 'status' ,
2828 'status' : 'connecting' ,
2929 'message' : 'Connecting to Polygon' })
30- ws = await websockets .connect (self ._endpoint )
30+ self ._ws = await websockets .connect (self ._endpoint )
31+ self ._stream = self ._recv ()
32+
33+ msg = await self ._next ()
34+ if msg .get ('status' ) != 'connected' :
35+ raise ValueError (
36+ ("Invalid response on Polygon websocket connection: {}"
37+ .format (msg ))
38+ )
39+ await self ._dispatch (msg )
40+ if await self .authenticate ():
41+ asyncio .ensure_future (self ._consume_msg ())
42+ else :
43+ await self .close ()
44+
45+ async def authenticate (self ):
46+ ws = self ._ws
47+ if not ws :
48+ return False
3149
3250 await ws .send (json .dumps ({
3351 'action' : 'auth' ,
3452 'params' : self ._key_id
3553 }))
36- r = await ws .recv ()
37- if isinstance (r , bytes ):
38- r = r .decode ('utf-8' )
39- msg = json .loads (r )
40- if msg [0 ].get ('status' ) != 'connected' :
41- raise ValueError (
42- ("Invalid Polygon credentials, Failed to authenticate: {}"
43- .format (msg ))
44- )
54+ data = await self ._next ()
55+ stream = data .get ('ev' )
56+ msg = data .get ('message' )
57+ status = data .get ('status' )
58+ if (stream == 'status'
59+ and msg == 'authenticated'
60+ and status == 'success' ):
61+ # reset retries only after we successfully authenticated
62+ self ._retries = 0
63+ await self ._dispatch (data )
64+ return True
65+ else :
66+ raise ValueError ('Invalid Polygon credentials, '
67+ f'Failed to authenticate: { data } ' )
4568
46- self ._retries = 0
47- self ._ws = ws
48- await self ._dispatch ('authorized' , msg [0 ])
69+ async def _next (self ):
70+ '''Returns the next message available
71+ '''
72+ return await self ._stream .__anext__ ()
4973
50- asyncio .ensure_future (self ._consume_msg ())
74+ async def _recv (self ):
75+ '''Function used to recieve and parse all messages from websocket stream.
5176
52- async def _consume_msg (self ):
53- ws = self ._ws
54- if not ws :
55- return
77+ This generator yields one message per each call.
78+ '''
5679 try :
5780 while True :
58- r = await ws .recv ()
81+ r = await self . _ws .recv ()
5982 if isinstance (r , bytes ):
6083 r = r .decode ('utf-8' )
6184 msg = json .loads (r )
6285 for update in msg :
63- stream = update .get ('ev' )
64- if stream is not None :
65- await self ._dispatch (stream , update )
66- except websockets .exceptions .ConnectionClosedError :
67- await self ._dispatch ('status' ,
68- {'ev' : 'status' ,
86+ yield update
87+ except websockets .exceptions .ConnectionClosedError as e :
88+ await self ._dispatch ({'ev' : 'status' ,
6989 'status' : 'disconnected' ,
7090 'message' :
71- 'Polygon Disconnected Unexpectedly' })
72- finally :
73- if self ._ws is not None :
74- await self ._ws .close ()
75- self ._ws = None
91+ f'Polygon Disconnected Unexpectedly ({ e } )' })
92+ await self .close ()
7693 asyncio .ensure_future (self ._ensure_ws ())
7794
95+ async def _consume_msg (self ):
96+ async for data in self ._stream :
97+ stream = data .get ('ev' )
98+ if stream :
99+ await self ._dispatch (data )
100+ elif data .get ('status' ) == 'disconnected' :
101+ # Polygon returns this on an empty 'ev' id..
102+ data ['ev' ] = 'status'
103+ await self ._dispatch (data )
104+ raise ConnectionResetError (
105+ 'Polygon terminated connection: '
106+ f'({ data .get ("message" )} )' )
107+
78108 async def _ensure_ws (self ):
79109 if self ._ws is not None :
80110 return
81- try :
82- await self .connect ()
83- except Exception :
84- self ._ws = None
85- self ._retries += 1
86- time .sleep (self ._retry_wait )
87- if self ._retries <= self ._retry :
88- asyncio .ensure_future (self ._ensure_ws ())
89- else :
90- raise ConnectionError ("Max Retries Exceeded" )
111+
112+ while self ._retries <= self ._retry :
113+ try :
114+ await self .connect ()
115+ if self ._streams :
116+ await self .subscribe (self ._streams )
117+
118+ break
119+ except (ConnectionRefusedError , ConnectionError ) as e :
120+ await self ._dispatch ({'ev' : 'status' ,
121+ 'status' : 'connect failed' ,
122+ 'message' :
123+ f'Connection Failed ({ e } )' })
124+ self ._ws = None
125+ self ._retries += 1
126+ time .sleep (self ._retry_wait * self ._retry )
127+ else :
128+ raise ConnectionError ("Max Retries Exceeded" )
91129
92130 async def subscribe (self , channels ):
93- '''Start subscribing channels.
131+ '''Subscribe to channels.
132+ Note: This is cumulative, meaning you can add channels at runtime,
133+ and you do not need to specify all the channels.
134+
135+ To remove channels see unsubscribe().
136+
94137 If the necessary connection isn't open yet, it opens now.
95138 '''
96139 if len (channels ) > 0 :
97140 await self ._ensure_ws ()
98141 # Join channel list to string
99142 streams = ',' .join (channels )
143+ self ._streams |= set (channels )
100144 await self ._ws .send (json .dumps ({
101145 'action' : 'subscribe' ,
102146 'params' : streams
103147 }))
104148
149+ async def unsubscribe (self , channels ):
150+ '''Unsubscribe from channels
151+ '''
152+ if not self ._ws :
153+ return
154+ if len (channels ) > 0 :
155+ # Join channel list to string
156+ streams = ',' .join (channels )
157+ self ._streams -= set (channels )
158+ await self ._ws .send (json .dumps ({
159+ 'action' : 'unsubscribe' ,
160+ 'params' : streams
161+ }))
162+
105163 def run (self , initial_channels = []):
106- '''Run forever and block until exception is rasised .
164+ '''Run forever and block until exception is raised .
107165 initial_channels is the channels to start with.
108166 '''
109167 loop = asyncio .get_event_loop ()
@@ -117,6 +175,7 @@ async def close(self):
117175 '''Close any of open connections'''
118176 if self ._ws is not None :
119177 await self ._ws .close ()
178+ self ._ws = None
120179
121180 def _cast (self , subject , data ):
122181 if subject == 'T' :
@@ -165,7 +224,8 @@ def _cast(self, subject, data):
165224 ent = Entity (data )
166225 return ent
167226
168- async def _dispatch (self , channel , msg ):
227+ async def _dispatch (self , msg ):
228+ channel = msg .get ('ev' )
169229 for pat , handler in self ._handlers .items ():
170230 if pat .match (channel ):
171231 ent = self ._cast (channel , msg )
0 commit comments