Skip to content

Commit 7da9542

Browse files
committed
Fix streaming test
1 parent 7f0fb76 commit 7da9542

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

alpaca_trade_api/polygon/stream2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,16 @@ async def _dispatch(self, channel, msg):
174174
ent = self._cast(channel, msg)
175175
await handler(self, channel, ent)
176176

177-
def register(self, channel_pat, func):
177+
def register(self, channel_pat, func, symbols=None):
178178
if not asyncio.iscoroutinefunction(func):
179179
raise ValueError('handler must be a coroutine function')
180180
if isinstance(channel_pat, str):
181181
channel_pat = re.compile(channel_pat)
182182
self._handlers[channel_pat] = func
183+
self._handler_symbols[func] = symbols
183184

184185
def deregister(self, channel_pat):
185186
if isinstance(channel_pat, str):
186187
channel_pat = re.compile(channel_pat)
187-
handler = self._handlers[channel_pat]
188-
del self._handler_symbols[handler]
188+
self._handler_symbols.pop(self._handlers[channel_pat], None)
189189
del self._handlers[channel_pat]

alpaca_trade_api/stream2.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,25 @@ async def _dispatch(self, channel, msg):
131131

132132
def on(self, channel_pat, symbols=None):
133133
def decorator(func):
134-
self.register(channel_pat, func)
135-
if symbols:
136-
self._handler_symbols[func] = symbols
134+
self.register(channel_pat, func, symbols)
137135
return func
138136

139137
return decorator
140138

141-
def register(self, channel_pat, func):
139+
def register(self, channel_pat, func, symbols=None):
142140
if not asyncio.iscoroutinefunction(func):
143141
raise ValueError('handler must be a coroutine function')
144142
if isinstance(channel_pat, str):
145143
channel_pat = re.compile(channel_pat)
146144
self._handlers[channel_pat] = func
145+
self._handler_symbols[func] = symbols
147146
if self.polygon:
148-
self.polygon.register(channel_pat, func)
147+
self.polygon.register(channel_pat, func, symbols)
149148

150149
def deregister(self, channel_pat):
151150
if isinstance(channel_pat, str):
152151
channel_pat = re.compile(channel_pat)
153-
handler = self._handlers[channel_pat]
154-
del self._handler_symbols[handler]
152+
self._handler_symbols.pop(self._handlers[channel_pat], None)
155153
del self._handlers[channel_pat]
156154
if self.polygon:
157155
self.polygon.deregister(channel_pat)

tests/test_stream2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from alpaca_trade_api.stream2 import StreamConn
2+
from alpaca_trade_api.polygon.stream2 import StreamConn as PolyStream
23
from alpaca_trade_api.entity import Account
34
from alpaca_trade_api.polygon.entity import Entity as PolyEntity
45
import asyncio
@@ -126,12 +127,15 @@ async def on_raise(conn, stream, msg):
126127
ent = conn._cast('other', {'key': 'value'})
127128
assert ent.key == 'value'
128129

129-
# _dispatch_nats
130+
# polygon _dispatch
130131
conn = StreamConn('key-id', 'secret-key')
132+
conn.polygon = PolyStream('key-id')
133+
msg_data = PolyEntity({'key': 'value'})
134+
conn.polygon._cast = mock.Mock(return_value=msg_data)
131135

132-
@conn.on('^Q.')
136+
@conn.on('Q')
133137
async def on_q(conn, subject, data):
134138
on_q.data = data
135139

136-
_run(conn._dispatch_polygon(conn, 'Q.SPY', PolyEntity({'key': 'value'})))
140+
_run(conn.polygon._dispatch('Q', msg_data))
137141
assert on_q.data.key == 'value'

0 commit comments

Comments
 (0)