|
4 | 4 | import re |
5 | 5 | import traceback |
6 | 6 | from asyncio import CancelledError |
| 7 | +import queue |
7 | 8 |
|
8 | 9 | import websockets |
9 | 10 | from .common import get_base_url, get_data_url, get_credentials, URL |
@@ -140,12 +141,15 @@ async def unsubscribe(self, channels): |
140 | 141 | })) |
141 | 142 |
|
142 | 143 | async def close(self): |
143 | | - if self._consume_task: |
144 | | - self._consume_task.cancel() |
| 144 | + await self.cancel_task() |
145 | 145 | if self._ws: |
146 | 146 | await self._ws.close() |
147 | 147 | self._ws = None |
148 | 148 |
|
| 149 | + async def cancel_task(self): |
| 150 | + if self._consume_task: |
| 151 | + self._consume_task.cancel() |
| 152 | + |
149 | 153 | def _cast(self, channel, msg): |
150 | 154 | if channel == 'account_updates': |
151 | 155 | return Account(msg) |
@@ -227,6 +231,7 @@ def __init__( |
227 | 231 | self._data_stream = _data_stream |
228 | 232 | self._debug = debug |
229 | 233 | self._raw_data = raw_data |
| 234 | + self._stop_stream_queue = queue.Queue() |
230 | 235 |
|
231 | 236 | self.trading_ws = _StreamConn(self._key_id, |
232 | 237 | self._secret_key, |
@@ -337,6 +342,9 @@ def run(self, initial_channels: List[str] = []): |
337 | 342 | logging.error(f"error while consuming ws messages: {m}") |
338 | 343 | if self._debug: |
339 | 344 | traceback.print_exc() |
| 345 | + if not self._stop_stream_queue.empty(): |
| 346 | + self._stop_stream_queue.get() |
| 347 | + should_renew = False |
340 | 348 | loop.run_until_complete(self.close(should_renew)) |
341 | 349 | if loop.is_running(): |
342 | 350 | loop.close() |
@@ -370,6 +378,18 @@ async def close(self, renew): |
370 | 378 | self._oauth, |
371 | 379 | raw_data=self._raw_data) |
372 | 380 |
|
| 381 | + async def stop_ws(self): |
| 382 | + """ |
| 383 | + Signal the ws connections to stop listenning to api stream. |
| 384 | + """ |
| 385 | + self._stop_stream_queue.put_nowait({"should_stop": True}) |
| 386 | + if self.trading_ws is not None: |
| 387 | + logging.info("Stopping the trading websocket connection") |
| 388 | + await self.trading_ws.cancel_task() |
| 389 | + if self.data_ws is not None: |
| 390 | + logging.info("Stopping the data websocket connection") |
| 391 | + await self.data_ws.cancel_task() |
| 392 | + |
373 | 393 | def on(self, channel_pat, symbols=None): |
374 | 394 | def decorator(func): |
375 | 395 | self.register(channel_pat, func, symbols) |
|
0 commit comments