Skip to content

Commit a678d50

Browse files
authored
Merge pull request #189 from alpacahq/simplify_stream
Simplify stream
2 parents 3c438bb + d11f226 commit a678d50

File tree

2 files changed

+58
-34
lines changed

2 files changed

+58
-34
lines changed

alpaca_trade_api/stream2.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ async def _ensure_ws(self):
8787
raise ConnectionError("Max Retries Exceeded")
8888

8989
async def subscribe(self, channels):
90+
if isinstance(channels, str):
91+
channels = [channels]
9092
if len(channels) > 0:
9193
await self._ensure_ws()
9294
self._streams |= set(channels)
@@ -98,9 +100,15 @@ async def subscribe(self, channels):
98100
}))
99101

100102
async def unsubscribe(self, channels):
101-
# Currently our streams don't support unsubscribe
102-
# not as useful with our feeds
103-
pass
103+
if isinstance(channels, str):
104+
channels = [channels]
105+
if len(channels) > 0:
106+
await self._ws.send(json.dumps({
107+
'action': 'unlisten',
108+
'data': {
109+
'streams': channels,
110+
}
111+
}))
104112

105113
async def close(self):
106114
if self._consume_task:
@@ -158,15 +166,31 @@ def __init__(
158166
key_id=None,
159167
secret_key=None,
160168
base_url=None,
161-
data_url=None):
169+
data_url=None,
170+
data_stream=None):
162171
_key_id, _secret_key, _ = get_credentials(key_id, secret_key)
163172
_base_url = base_url or get_base_url()
164173
_data_url = data_url or get_data_url()
174+
if data_stream is not None:
175+
if data_stream in ('alpacadatav1', 'polygon'):
176+
_data_stream = data_stream
177+
else:
178+
raise ValueError('invalid data_stream name {}'.format(
179+
data_stream))
180+
else:
181+
_data_stream = 'alpacadatav1'
182+
self._data_stream = _data_stream
165183

166184
self.trading_ws = _StreamConn(_key_id, _secret_key, _base_url)
167-
self.data_ws = _StreamConn(_key_id, _secret_key, _data_url)
168-
self.polygon = polygon.StreamConn(
169-
_key_id + '-staging' if 'staging' in _base_url else _key_id)
185+
186+
if self._data_stream == 'polygon':
187+
self.data_ws = polygon.StreamConn(
188+
_key_id + '-staging' if 'staging' in _base_url else _key_id)
189+
self._data_prefixes = (('Q.', 'T.', 'A.', 'AM.'))
190+
else:
191+
self.data_ws = _StreamConn(_key_id, _secret_key, _data_url)
192+
self._data_prefixes = (
193+
('Q.', 'T.', 'AM.', 'polyfeed/', 'alpacadatav1/'))
170194

171195
self._handlers = {}
172196
self._handler_symbols = {}
@@ -191,34 +215,41 @@ async def _ensure_ws(self, conn):
191215
async def subscribe(self, channels):
192216
'''Start subscribing to channels.
193217
If the necessary connection isn't open yet, it opens now.
218+
This may raise ValueError if a channel is not recognized.
194219
'''
195-
trading_channels, data_channels, polygon_channels = [], [], []
220+
trading_channels, data_channels = [], []
221+
196222
for c in channels:
197-
if c.startswith(('Q.', 'T.', 'A.', 'AM.',)):
198-
polygon_channels.append(c)
199-
elif c in ('trade_updates', 'account_updates'):
223+
if c in ('trade_updates', 'account_updates'):
200224
trading_channels.append(c)
201-
else:
225+
elif c.startswith(self._data_prefixes):
202226
data_channels.append(c)
227+
else:
228+
raise ValueError(
229+
('unknown channel {} (you may need to specify ' +
230+
'the right data_stream)').format(c))
203231

204232
if trading_channels:
205233
await self._ensure_ws(self.trading_ws)
206234
await self.trading_ws.subscribe(trading_channels)
207235
if data_channels:
208236
await self._ensure_ws(self.data_ws)
209237
await self.data_ws.subscribe(data_channels)
210-
if polygon_channels:
211-
await self._ensure_ws(self.polygon)
212-
await self.polygon.subscribe(polygon_channels)
213238

214239
async def unsubscribe(self, channels):
215240
'''Handle unsubscribing from channels.'''
216-
polygon_channels = [
241+
242+
data_prefixes = ('Q.', 'T.', 'AM.')
243+
if self._data_stream == 'polygon':
244+
data_prefixes = ('Q.', 'T.', 'A.', 'AM.')
245+
246+
data_channels = [
217247
c for c in channels
218-
if c.startswith(('Q.', 'T.', 'A.', 'AM.',))
248+
if c.startswith(data_prefixes)
219249
]
220-
if polygon_channels:
221-
await self.polygon.unsubscribe(polygon_channels)
250+
251+
if data_channels:
252+
await self.data_ws.unsubscribe(data_channels)
222253

223254
def run(self, initial_channels=[]):
224255
'''Run forever and block until exception is raised.
@@ -242,9 +273,6 @@ async def close(self):
242273
if self.data_ws is not None:
243274
await self.data_ws.close()
244275
self.data_ws = None
245-
if self.polygon is not None:
246-
await self.polygon.close()
247-
self.polygon = None
248276

249277
def on(self, channel_pat, symbols=None):
250278
def decorator(func):
@@ -265,8 +293,6 @@ def register(self, channel_pat, func, symbols=None):
265293
self.trading_ws.register(channel_pat, func, symbols)
266294
if self.data_ws:
267295
self.data_ws.register(channel_pat, func, symbols)
268-
if self.polygon:
269-
self.polygon.register(channel_pat, func, symbols)
270296

271297
def deregister(self, channel_pat):
272298
if isinstance(channel_pat, str):
@@ -278,5 +304,3 @@ def deregister(self, channel_pat):
278304
self.trading_ws.deregister(channel_pat)
279305
if self.data_ws:
280306
self.data_ws.deregister(channel_pat)
281-
if self.polygon:
282-
self.polygon.deregister(channel_pat)

tests/test_stream2.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ async def on_raise(conn, stream, msg):
8383
polygon.StreamConn().connect = AsyncMock()
8484
polygon.StreamConn()._handlers = None
8585

86-
conn = StreamConn('key-id', 'secret-key')
87-
_run(conn._ensure_ws(conn.polygon))
88-
assert conn.polygon is not None
89-
assert conn.polygon.connect.mock.called
86+
conn = StreamConn('key-id', 'secret-key', data_stream='polygon')
87+
_run(conn._ensure_ws(conn.data_ws))
88+
assert conn.data_ws is not None
89+
assert conn.data_ws.connect.mock.called
9090

9191
# _ensure_ws
9292
conn = StreamConn('key-id', 'secret-key')
@@ -119,14 +119,14 @@ async def on_raise(conn, stream, msg):
119119
assert ent.key == 'value'
120120

121121
# polygon _dispatch
122-
conn = StreamConn('key-id', 'secret-key')
123-
conn.polygon = PolyStream('key-id')
122+
conn = StreamConn('key-id', 'secret-key', data_stream='polygon')
123+
conn.data_ws = PolyStream('key-id')
124124
msg_data = {'key': 'value', 'ev': 'Q'}
125-
conn.polygon._cast = mock.Mock(return_value=msg_data)
125+
conn.data_ws._cast = mock.Mock(return_value=msg_data)
126126

127127
@conn.on('Q')
128128
async def on_q(conn, subject, data):
129129
on_q.data = data
130130

131-
_run(conn.polygon._dispatch(msg_data))
131+
_run(conn.data_ws._dispatch(msg_data))
132132
assert on_q.data['key'] == 'value'

0 commit comments

Comments
 (0)