|
6 | 6 | from .common import get_base_url, get_data_url, get_credentials |
7 | 7 | from .entity import Account, Entity, trade_mapping, agg_mapping, quote_mapping |
8 | 8 | from . import polygon |
9 | | -from .polygon.entity import Trade, Quote, Agg |
| 9 | +from .entity import Trade, Quote, Agg |
10 | 10 | import logging |
11 | 11 |
|
12 | 12 |
|
@@ -52,6 +52,10 @@ async def _connect(self): |
52 | 52 |
|
53 | 53 | self._consume_task = asyncio.ensure_future(self._consume_msg()) |
54 | 54 |
|
| 55 | + async def consume(self): |
| 56 | + if self._consume_task: |
| 57 | + await self._consume_task |
| 58 | + |
55 | 59 | async def _consume_msg(self): |
56 | 60 | ws = self._ws |
57 | 61 | try: |
@@ -239,26 +243,28 @@ async def subscribe(self, channels): |
239 | 243 | async def unsubscribe(self, channels): |
240 | 244 | '''Handle unsubscribing from channels.''' |
241 | 245 |
|
242 | | - data_prefixes = ('Q.', 'T.', 'AM.') |
243 | | - if self._data_stream == 'polygon': |
244 | | - data_prefixes = ('Q.', 'T.', 'A.', 'AM.') |
245 | | - |
246 | 246 | data_channels = [ |
247 | 247 | c for c in channels |
248 | | - if c.startswith(data_prefixes) |
| 248 | + if c.startswith(self._data_prefixes) |
249 | 249 | ] |
250 | 250 |
|
251 | 251 | if data_channels: |
252 | 252 | await self.data_ws.unsubscribe(data_channels) |
253 | 253 |
|
| 254 | + async def consume(self): |
| 255 | + await asyncio.gather( |
| 256 | + self.trading_ws.consume(), |
| 257 | + self.data_ws.consume(), |
| 258 | + ) |
| 259 | + |
254 | 260 | def run(self, initial_channels=[]): |
255 | 261 | '''Run forever and block until exception is raised. |
256 | 262 | initial_channels is the channels to start with. |
257 | 263 | ''' |
258 | 264 | loop = self.loop |
259 | 265 | try: |
260 | 266 | loop.run_until_complete(self.subscribe(initial_channels)) |
261 | | - loop.run_forever() |
| 267 | + loop.run_until_complete(self.consume()) |
262 | 268 | except KeyboardInterrupt: |
263 | 269 | logging.info("Exiting on Interrupt") |
264 | 270 | finally: |
|
0 commit comments