Skip to content

Commit 1bce590

Browse files
authored
Merge pull request #89 from alpacahq/feature/polygon_reconnect
Further improvments to feature/polygon_reconnect
2 parents 5ca968b + 49c6197 commit 1bce590

File tree

3 files changed

+137
-46
lines changed

3 files changed

+137
-46
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,17 @@ conn = StreamConn()
187187
async def on_account_updates(conn, channel, account):
188188
print('account', account)
189189

190+
@conn.on(r'^status$')
191+
def on_status(conn, channel, data):
192+
print('polygon status update', data)
190193

191194
@conn.on(r'^AM$')
192-
def on_bars(conn, channel, bar):
195+
def on_minute_bars(conn, channel, bar):
193196
print('bars', bar)
194197

198+
@conn.on(r'^A$')
199+
def on_second_bars(conn, channel, bar):
200+
print('bars', bar)
195201

196202
# blocks forever
197203
conn.run(['account_updates', 'AM.*'])
@@ -204,6 +210,9 @@ unless an exception is raised.
204210
### StreamConn.subscribe(channels)
205211
Request "listen" to the server. `channels` must be a list of string channel names.
206212

213+
### StreamConn.unsubscribe(channels)
214+
Request to stop "listening" to the server. `channels` must be a list of string channel names.
215+
207216
### StreamConn.run(channels)
208217
Goes into an infinite loop and awaits for messages from the server. You should
209218
set up event listeners using the `on` or `register` method before calling `run`.

alpaca_trade_api/polygon/stream2.py

Lines changed: 105 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,151 @@ def __init__(self, key_id=None):
1717
'wss://alpaca.socket.polygon.io/stocks'
1818
).rstrip('/')
1919
self._handlers = {}
20+
self._streams = set([])
2021
self._ws = None
2122
self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
2223
self._retry_wait = int(os.environ.get('APCA_RETRY_WAIT', 3))
2324
self._retries = 0
2425

2526
async def connect(self):
26-
await self._dispatch('status',
27-
{'ev': 'status',
27+
await self._dispatch({'ev': 'status',
2828
'status': 'connecting',
2929
'message': 'Connecting to Polygon'})
30-
ws = await websockets.connect(self._endpoint)
30+
self._ws = await websockets.connect(self._endpoint)
31+
self._stream = self._recv()
32+
33+
msg = await self._next()
34+
if msg.get('status') != 'connected':
35+
raise ValueError(
36+
("Invalid response on Polygon websocket connection: {}"
37+
.format(msg))
38+
)
39+
await self._dispatch(msg)
40+
if await self.authenticate():
41+
asyncio.ensure_future(self._consume_msg())
42+
else:
43+
await self.close()
44+
45+
async def authenticate(self):
46+
ws = self._ws
47+
if not ws:
48+
return False
3149

3250
await ws.send(json.dumps({
3351
'action': 'auth',
3452
'params': self._key_id
3553
}))
36-
r = await ws.recv()
37-
if isinstance(r, bytes):
38-
r = r.decode('utf-8')
39-
msg = json.loads(r)
40-
if msg[0].get('status') != 'connected':
41-
raise ValueError(
42-
("Invalid Polygon credentials, Failed to authenticate: {}"
43-
.format(msg))
44-
)
54+
data = await self._next()
55+
stream = data.get('ev')
56+
msg = data.get('message')
57+
status = data.get('status')
58+
if (stream == 'status'
59+
and msg == 'authenticated'
60+
and status == 'success'):
61+
# reset retries only after we successfully authenticated
62+
self._retries = 0
63+
await self._dispatch(data)
64+
return True
65+
else:
66+
raise ValueError('Invalid Polygon credentials, '
67+
f'Failed to authenticate: {data}')
4568

46-
self._retries = 0
47-
self._ws = ws
48-
await self._dispatch('authorized', msg[0])
69+
async def _next(self):
70+
'''Returns the next message available
71+
'''
72+
return await self._stream.__anext__()
4973

50-
asyncio.ensure_future(self._consume_msg())
74+
async def _recv(self):
75+
'''Function used to recieve and parse all messages from websocket stream.
5176
52-
async def _consume_msg(self):
53-
ws = self._ws
54-
if not ws:
55-
return
77+
This generator yields one message per each call.
78+
'''
5679
try:
5780
while True:
58-
r = await ws.recv()
81+
r = await self._ws.recv()
5982
if isinstance(r, bytes):
6083
r = r.decode('utf-8')
6184
msg = json.loads(r)
6285
for update in msg:
63-
stream = update.get('ev')
64-
if stream is not None:
65-
await self._dispatch(stream, update)
66-
except websockets.exceptions.ConnectionClosedError:
67-
await self._dispatch('status',
68-
{'ev': 'status',
86+
yield update
87+
except websockets.exceptions.ConnectionClosedError as e:
88+
await self._dispatch({'ev': 'status',
6989
'status': 'disconnected',
7090
'message':
71-
'Polygon Disconnected Unexpectedly'})
72-
finally:
73-
if self._ws is not None:
74-
await self._ws.close()
75-
self._ws = None
91+
f'Polygon Disconnected Unexpectedly ({e})'})
92+
await self.close()
7693
asyncio.ensure_future(self._ensure_ws())
7794

95+
async def _consume_msg(self):
96+
async for data in self._stream:
97+
stream = data.get('ev')
98+
if stream:
99+
await self._dispatch(data)
100+
elif data.get('status') == 'disconnected':
101+
# Polygon returns this on an empty 'ev' id..
102+
data['ev'] = 'status'
103+
await self._dispatch(data)
104+
raise ConnectionResetError(
105+
'Polygon terminated connection: '
106+
f'({data.get("message")})')
107+
78108
async def _ensure_ws(self):
79109
if self._ws is not None:
80110
return
81-
try:
82-
await self.connect()
83-
except Exception:
84-
self._ws = None
85-
self._retries += 1
86-
time.sleep(self._retry_wait)
87-
if self._retries <= self._retry:
88-
asyncio.ensure_future(self._ensure_ws())
89-
else:
90-
raise ConnectionError("Max Retries Exceeded")
111+
112+
while self._retries <= self._retry:
113+
try:
114+
await self.connect()
115+
if self._streams:
116+
await self.subscribe(self._streams)
117+
118+
break
119+
except (ConnectionRefusedError, ConnectionError) as e:
120+
await self._dispatch({'ev': 'status',
121+
'status': 'connect failed',
122+
'message':
123+
f'Connection Failed ({e})'})
124+
self._ws = None
125+
self._retries += 1
126+
time.sleep(self._retry_wait * self._retry)
127+
else:
128+
raise ConnectionError("Max Retries Exceeded")
91129

92130
async def subscribe(self, channels):
93-
'''Start subscribing channels.
131+
'''Subscribe to channels.
132+
Note: This is cumulative, meaning you can add channels at runtime,
133+
and you do not need to specify all the channels.
134+
135+
To remove channels see unsubscribe().
136+
94137
If the necessary connection isn't open yet, it opens now.
95138
'''
96139
if len(channels) > 0:
97140
await self._ensure_ws()
98141
# Join channel list to string
99142
streams = ','.join(channels)
143+
self._streams |= set(channels)
100144
await self._ws.send(json.dumps({
101145
'action': 'subscribe',
102146
'params': streams
103147
}))
104148

149+
async def unsubscribe(self, channels):
150+
'''Unsubscribe from channels
151+
'''
152+
if not self._ws:
153+
return
154+
if len(channels) > 0:
155+
# Join channel list to string
156+
streams = ','.join(channels)
157+
self._streams -= set(channels)
158+
await self._ws.send(json.dumps({
159+
'action': 'unsubscribe',
160+
'params': streams
161+
}))
162+
105163
def run(self, initial_channels=[]):
106-
'''Run forever and block until exception is rasised.
164+
'''Run forever and block until exception is raised.
107165
initial_channels is the channels to start with.
108166
'''
109167
loop = asyncio.get_event_loop()
@@ -117,6 +175,7 @@ async def close(self):
117175
'''Close any of open connections'''
118176
if self._ws is not None:
119177
await self._ws.close()
178+
self._ws = None
120179

121180
def _cast(self, subject, data):
122181
if subject == 'T':
@@ -165,7 +224,8 @@ def _cast(self, subject, data):
165224
ent = Entity(data)
166225
return ent
167226

168-
async def _dispatch(self, channel, msg):
227+
async def _dispatch(self, msg):
228+
channel = msg.get('ev')
169229
for pat, handler in self._handlers.items():
170230
if pat.match(channel):
171231
ent = self._cast(channel, msg)

alpaca_trade_api/stream2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,28 @@ async def subscribe(self, channels):
9898
await self._ensure_polygon()
9999
await self.polygon.subscribe(polygon_channels)
100100

101+
async def unsubscribe(self, channels):
102+
'''Handle un-subscribing from channels.
103+
'''
104+
if not self._ws:
105+
return
106+
107+
ws_channels = []
108+
polygon_channels = []
109+
for c in channels:
110+
if c.startswith(('Q.', 'T.', 'A.', 'AM.',)):
111+
polygon_channels.append(c)
112+
else:
113+
ws_channels.append(c)
114+
115+
if len(ws_channels) > 0:
116+
# Currently our streams don't support unsubscribe
117+
# not as useful with our feeds
118+
pass
119+
120+
if len(polygon_channels) > 0:
121+
await self.polygon.unsubscribe(polygon_channels)
122+
101123
def run(self, initial_channels=[]):
102124
'''Run forever and block until exception is rasised.
103125
initial_channels is the channels to start with.

0 commit comments

Comments
 (0)