Skip to content

Commit 63addf7

Browse files
authored
Merge pull request #90 from alpacahq/feature/ws-handler-registry
Fix Polygon handler registry and add symbol specification
2 parents 213bf8d + 1736075 commit 63addf7

File tree

3 files changed

+29
-25
lines changed

3 files changed

+29
-25
lines changed

alpaca_trade_api/polygon/stream2.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, key_id=None):
1818
'wss://alpaca.socket.polygon.io/stocks'
1919
).rstrip('/')
2020
self._handlers = {}
21+
self._handler_symbols = {}
2122
self._streams = set([])
2223
self._ws = None
2324
self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
@@ -180,7 +181,7 @@ def run(self, initial_channels=[]):
180181
loop.close()
181182

182183
async def close(self):
183-
'''Close any of open connections'''
184+
'''Close any open connections'''
184185
if self._ws is not None:
185186
await self._ws.close()
186187
self._ws = None
@@ -236,24 +237,21 @@ async def _dispatch(self, msg):
236237
channel = msg.get('ev')
237238
for pat, handler in self._handlers.items():
238239
if pat.match(channel):
239-
ent = self._cast(channel, msg)
240-
await handler(self, channel, ent)
240+
handled_symbols = self._handler_symbols.get(handler)
241+
if handled_symbols is None or msg['sym'] in handled_symbols:
242+
ent = self._cast(channel, msg)
243+
await handler(self, channel, ent)
241244

242-
def on(self, channel_pat):
243-
def decorator(func):
244-
self.register(channel_pat, func)
245-
return func
246-
247-
return decorator
248-
249-
def register(self, channel_pat, func):
245+
def register(self, channel_pat, func, symbols=None):
250246
if not asyncio.iscoroutinefunction(func):
251247
raise ValueError('handler must be a coroutine function')
252248
if isinstance(channel_pat, str):
253249
channel_pat = re.compile(channel_pat)
254250
self._handlers[channel_pat] = func
251+
self._handler_symbols[func] = symbols
255252

256253
def deregister(self, channel_pat):
257254
if isinstance(channel_pat, str):
258255
channel_pat = re.compile(channel_pat)
256+
self._handler_symbols.pop(self._handlers[channel_pat], None)
259257
del self._handlers[channel_pat]

alpaca_trade_api/stream2.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, key_id=None, secret_key=None, base_url=None):
1414
base_url = re.sub(r'^http', 'ws', base_url or get_base_url())
1515
self._endpoint = base_url + '/stream'
1616
self._handlers = {}
17+
self._handler_symbols = {}
1718
self._base_url = base_url
1819
self._ws = None
1920
self.polygon = None
@@ -67,7 +68,8 @@ async def _ensure_polygon(self):
6768
if 'staging' in self._base_url:
6869
key_id += '-staging'
6970
self.polygon = polygon.StreamConn(key_id)
70-
self.polygon.register(r'.*', self._dispatch_polygon)
71+
self.polygon._handlers = self._handlers
72+
self.polygon._handler_symbols = self._handler_symbols
7173
await self.polygon.connect()
7274

7375
async def _ensure_ws(self):
@@ -148,32 +150,33 @@ def _cast(self, channel, msg):
148150
return Account(msg)
149151
return Entity(msg)
150152

151-
async def _dispatch_polygon(self, conn, subject, data):
152-
for pat, handler in self._handlers.items():
153-
if pat.match(subject):
154-
await handler(self, subject, data)
155-
156153
async def _dispatch(self, channel, msg):
157154
for pat, handler in self._handlers.items():
158155
if pat.match(channel):
159156
ent = self._cast(channel, msg['data'])
160157
await handler(self, channel, ent)
161158

162-
def on(self, channel_pat):
159+
def on(self, channel_pat, symbols=None):
163160
def decorator(func):
164-
self.register(channel_pat, func)
161+
self.register(channel_pat, func, symbols)
165162
return func
166163

167164
return decorator
168165

169-
def register(self, channel_pat, func):
166+
def register(self, channel_pat, func, symbols=None):
170167
if not asyncio.iscoroutinefunction(func):
171168
raise ValueError('handler must be a coroutine function')
172169
if isinstance(channel_pat, str):
173170
channel_pat = re.compile(channel_pat)
174171
self._handlers[channel_pat] = func
172+
self._handler_symbols[func] = symbols
173+
if self.polygon:
174+
self.polygon.register(channel_pat, func, symbols)
175175

176176
def deregister(self, channel_pat):
177177
if isinstance(channel_pat, str):
178178
channel_pat = re.compile(channel_pat)
179+
self._handler_symbols.pop(self._handlers[channel_pat], None)
179180
del self._handlers[channel_pat]
181+
if self.polygon:
182+
self.polygon.deregister(channel_pat)

tests/test_stream2.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from alpaca_trade_api.stream2 import StreamConn
2+
from alpaca_trade_api.polygon.stream2 import StreamConn as PolyStream
23
from alpaca_trade_api.entity import Account
3-
from alpaca_trade_api.polygon.entity import Entity as PolyEntity
44
import asyncio
55
import json
66

@@ -126,12 +126,15 @@ async def on_raise(conn, stream, msg):
126126
ent = conn._cast('other', {'key': 'value'})
127127
assert ent.key == 'value'
128128

129-
# _dispatch_nats
129+
# polygon _dispatch
130130
conn = StreamConn('key-id', 'secret-key')
131+
conn.polygon = PolyStream('key-id')
132+
msg_data = {'key': 'value', 'ev': 'Q'}
133+
conn.polygon._cast = mock.Mock(return_value=msg_data)
131134

132-
@conn.on('^Q.')
135+
@conn.on('Q')
133136
async def on_q(conn, subject, data):
134137
on_q.data = data
135138

136-
_run(conn._dispatch_polygon(conn, 'Q.SPY', PolyEntity({'key': 'value'})))
137-
assert on_q.data.key == 'value'
139+
_run(conn.polygon._dispatch(msg_data))
140+
assert on_q.data['key'] == 'value'

0 commit comments

Comments
 (0)