Skip to content

Commit 9347e3b

Browse files
committed
Merge branch 'master' of https://github.com/alpacahq/alpaca-trade-api-python into feature/ws-handler-registry
2 parents 7da9542 + 1bce590 commit 9347e3b

File tree

3 files changed

+137
-46
lines changed

3 files changed

+137
-46
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,17 @@ conn = StreamConn()
187187
async def on_account_updates(conn, channel, account):
188188
print('account', account)
189189

190+
@conn.on(r'^status$')
191+
def on_status(conn, channel, data):
192+
print('polygon status update', data)
190193

191194
@conn.on(r'^AM$')
192-
def on_bars(conn, channel, bar):
195+
def on_minute_bars(conn, channel, bar):
193196
print('bars', bar)
194197

198+
@conn.on(r'^A$')
199+
def on_second_bars(conn, channel, bar):
200+
print('bars', bar)
195201

196202
# blocks forever
197203
conn.run(['account_updates', 'AM.*'])
@@ -204,6 +210,9 @@ unless an exception is raised.
204210
### StreamConn.subscribe(channels)
205211
Request "listen" to the server. `channels` must be a list of string channel names.
206212

213+
### StreamConn.unsubscribe(channels)
214+
Request to stop "listening" to the server. `channels` must be a list of string channel names.
215+
207216
### StreamConn.run(channels)
208217
Goes into an infinite loop and awaits for messages from the server. You should
209218
set up event listeners using the `on` or `register` method before calling `run`.

alpaca_trade_api/polygon/stream2.py

Lines changed: 105 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,93 +18,151 @@ def __init__(self, key_id=None):
1818
).rstrip('/')
1919
self._handlers = {}
2020
self._handler_symbols = {}
21+
self._streams = set([])
2122
self._ws = None
2223
self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
2324
self._retry_wait = int(os.environ.get('APCA_RETRY_WAIT', 3))
2425
self._retries = 0
2526

2627
async def connect(self):
27-
await self._dispatch('status',
28-
{'ev': 'status',
28+
await self._dispatch({'ev': 'status',
2929
'status': 'connecting',
3030
'message': 'Connecting to Polygon'})
31-
ws = await websockets.connect(self._endpoint)
31+
self._ws = await websockets.connect(self._endpoint)
32+
self._stream = self._recv()
33+
34+
msg = await self._next()
35+
if msg.get('status') != 'connected':
36+
raise ValueError(
37+
("Invalid response on Polygon websocket connection: {}"
38+
.format(msg))
39+
)
40+
await self._dispatch(msg)
41+
if await self.authenticate():
42+
asyncio.ensure_future(self._consume_msg())
43+
else:
44+
await self.close()
45+
46+
async def authenticate(self):
47+
ws = self._ws
48+
if not ws:
49+
return False
3250

3351
await ws.send(json.dumps({
3452
'action': 'auth',
3553
'params': self._key_id
3654
}))
37-
r = await ws.recv()
38-
if isinstance(r, bytes):
39-
r = r.decode('utf-8')
40-
msg = json.loads(r)
41-
if msg[0].get('status') != 'connected':
42-
raise ValueError(
43-
("Invalid Polygon credentials, Failed to authenticate: {}"
44-
.format(msg))
45-
)
55+
data = await self._next()
56+
stream = data.get('ev')
57+
msg = data.get('message')
58+
status = data.get('status')
59+
if (stream == 'status'
60+
and msg == 'authenticated'
61+
and status == 'success'):
62+
# reset retries only after we successfully authenticated
63+
self._retries = 0
64+
await self._dispatch(data)
65+
return True
66+
else:
67+
raise ValueError('Invalid Polygon credentials, '
68+
f'Failed to authenticate: {data}')
4669

47-
self._retries = 0
48-
self._ws = ws
49-
await self._dispatch('authorized', msg[0])
70+
async def _next(self):
71+
'''Returns the next message available
72+
'''
73+
return await self._stream.__anext__()
5074

51-
asyncio.ensure_future(self._consume_msg())
75+
async def _recv(self):
76+
'''Function used to recieve and parse all messages from websocket stream.
5277
53-
async def _consume_msg(self):
54-
ws = self._ws
55-
if not ws:
56-
return
78+
This generator yields one message per each call.
79+
'''
5780
try:
5881
while True:
59-
r = await ws.recv()
82+
r = await self._ws.recv()
6083
if isinstance(r, bytes):
6184
r = r.decode('utf-8')
6285
msg = json.loads(r)
6386
for update in msg:
64-
stream = update.get('ev')
65-
if stream is not None:
66-
await self._dispatch(stream, update)
67-
except websockets.exceptions.ConnectionClosedError:
68-
await self._dispatch('status',
69-
{'ev': 'status',
87+
yield update
88+
except websockets.exceptions.ConnectionClosedError as e:
89+
await self._dispatch({'ev': 'status',
7090
'status': 'disconnected',
7191
'message':
72-
'Polygon Disconnected Unexpectedly'})
73-
finally:
74-
if self._ws is not None:
75-
await self._ws.close()
76-
self._ws = None
92+
f'Polygon Disconnected Unexpectedly ({e})'})
93+
await self.close()
7794
asyncio.ensure_future(self._ensure_ws())
7895

96+
async def _consume_msg(self):
97+
async for data in self._stream:
98+
stream = data.get('ev')
99+
if stream:
100+
await self._dispatch(data)
101+
elif data.get('status') == 'disconnected':
102+
# Polygon returns this on an empty 'ev' id..
103+
data['ev'] = 'status'
104+
await self._dispatch(data)
105+
raise ConnectionResetError(
106+
'Polygon terminated connection: '
107+
f'({data.get("message")})')
108+
79109
async def _ensure_ws(self):
80110
if self._ws is not None:
81111
return
82-
try:
83-
await self.connect()
84-
except Exception:
85-
self._ws = None
86-
self._retries += 1
87-
time.sleep(self._retry_wait)
88-
if self._retries <= self._retry:
89-
asyncio.ensure_future(self._ensure_ws())
90-
else:
91-
raise ConnectionError("Max Retries Exceeded")
112+
113+
while self._retries <= self._retry:
114+
try:
115+
await self.connect()
116+
if self._streams:
117+
await self.subscribe(self._streams)
118+
119+
break
120+
except (ConnectionRefusedError, ConnectionError) as e:
121+
await self._dispatch({'ev': 'status',
122+
'status': 'connect failed',
123+
'message':
124+
f'Connection Failed ({e})'})
125+
self._ws = None
126+
self._retries += 1
127+
time.sleep(self._retry_wait * self._retry)
128+
else:
129+
raise ConnectionError("Max Retries Exceeded")
92130

93131
async def subscribe(self, channels):
94-
'''Start subscribing channels.
132+
'''Subscribe to channels.
133+
Note: This is cumulative, meaning you can add channels at runtime,
134+
and you do not need to specify all the channels.
135+
136+
To remove channels see unsubscribe().
137+
95138
If the necessary connection isn't open yet, it opens now.
96139
'''
97140
if len(channels) > 0:
98141
await self._ensure_ws()
99142
# Join channel list to string
100143
streams = ','.join(channels)
144+
self._streams |= set(channels)
101145
await self._ws.send(json.dumps({
102146
'action': 'subscribe',
103147
'params': streams
104148
}))
105149

150+
async def unsubscribe(self, channels):
151+
'''Unsubscribe from channels
152+
'''
153+
if not self._ws:
154+
return
155+
if len(channels) > 0:
156+
# Join channel list to string
157+
streams = ','.join(channels)
158+
self._streams -= set(channels)
159+
await self._ws.send(json.dumps({
160+
'action': 'unsubscribe',
161+
'params': streams
162+
}))
163+
106164
def run(self, initial_channels=[]):
107-
'''Run forever and block until exception is rasised.
165+
'''Run forever and block until exception is raised.
108166
initial_channels is the channels to start with.
109167
'''
110168
loop = asyncio.get_event_loop()
@@ -118,6 +176,7 @@ async def close(self):
118176
'''Close any open connections'''
119177
if self._ws is not None:
120178
await self._ws.close()
179+
self._ws = None
121180

122181
def _cast(self, subject, data):
123182
if subject == 'T':
@@ -166,7 +225,8 @@ def _cast(self, subject, data):
166225
ent = Entity(data)
167226
return ent
168227

169-
async def _dispatch(self, channel, msg):
228+
async def _dispatch(self, msg):
229+
channel = msg.get('ev')
170230
for pat, handler in self._handlers.items():
171231
if pat.match(channel):
172232
handled_symbols = self._handler_symbols.get(handler)

alpaca_trade_api/stream2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,28 @@ async def subscribe(self, channels):
100100
await self._ensure_polygon()
101101
await self.polygon.subscribe(polygon_channels)
102102

103+
async def unsubscribe(self, channels):
104+
'''Handle un-subscribing from channels.
105+
'''
106+
if not self._ws:
107+
return
108+
109+
ws_channels = []
110+
polygon_channels = []
111+
for c in channels:
112+
if c.startswith(('Q.', 'T.', 'A.', 'AM.',)):
113+
polygon_channels.append(c)
114+
else:
115+
ws_channels.append(c)
116+
117+
if len(ws_channels) > 0:
118+
# Currently our streams don't support unsubscribe
119+
# not as useful with our feeds
120+
pass
121+
122+
if len(polygon_channels) > 0:
123+
await self.polygon.unsubscribe(polygon_channels)
124+
103125
def run(self, initial_channels=[]):
104126
'''Run forever and block until exception is rasised.
105127
initial_channels is the channels to start with.

0 commit comments

Comments
 (0)