Skip to content

Commit 91d84d3

Browse files
authored
Merge pull request #111 from alpacahq/feature/alpaca-stream-reconnect
Add automatic reconnection for Alpaca stream
2 parents 71646b0 + fe72f28 commit 91d84d3

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

alpaca_trade_api/polygon/streamconn.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,7 @@ async def _recv(self):
8888
msg = json.loads(r)
8989
for update in msg:
9090
yield update
91-
except websockets.exceptions.ConnectionClosed:
92-
# Ignore, occurs on self.close() such as after KeyboardInterrupt
93-
pass
94-
except websockets.exceptions.ConnectionClosedError as e:
91+
except Exception as e:
9592
await self._dispatch({'ev': 'status',
9693
'status': 'disconnected',
9794
'message':
@@ -121,13 +118,12 @@ async def _ensure_ws(self):
121118
await self.connect()
122119
if self._streams:
123120
await self.subscribe(self._streams)
124-
125121
break
126-
except (ConnectionRefusedError, ConnectionError) as e:
122+
except Exception as e:
127123
await self._dispatch({'ev': 'status',
128124
'status': 'connect failed',
129125
'message':
130-
f'Connection Failed ({e})'})
126+
f'Polygon Connection Failed ({e})'})
131127
self._ws = None
132128
self._retries += 1
133129
time.sleep(self._retry_wait * self._retry)

alpaca_trade_api/stream2.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import os
34
import re
45
import websockets
56
from .common import get_base_url, get_credentials
@@ -16,11 +17,15 @@ def __init__(self, key_id=None, secret_key=None, base_url=None):
1617
self._handlers = {}
1718
self._handler_symbols = {}
1819
self._base_url = base_url
20+
self._streams = set([])
1921
self._ws = None
22+
self._retry = int(os.environ.get('APCA_RETRY_MAX', 3))
23+
self._retry_wait = int(os.environ.get('APCA_RETRY_WAIT', 3))
24+
self._retries = 0
2025
self.polygon = None
2126
try:
2227
self.loop = asyncio.get_event_loop()
23-
except:
28+
except Exception:
2429
self.loop = asyncio.new_event_loop()
2530
asyncio.set_event_loop(self.loop)
2631

@@ -43,12 +48,13 @@ async def _connect(self):
4348
("Invalid Alpaca API credentials, Failed to authenticate: {}"
4449
.format(msg))
4550
)
51+
else:
52+
self._retries = 0
4653

4754
self._ws = ws
4855
await self._dispatch('authorized', msg)
4956

5057
asyncio.ensure_future(self._consume_msg())
51-
return ws
5258

5359
async def _consume_msg(self):
5460
ws = self._ws
@@ -61,9 +67,9 @@ async def _consume_msg(self):
6167
stream = msg.get('stream')
6268
if stream is not None:
6369
await self._dispatch(stream, msg)
64-
finally:
65-
await ws.close()
66-
self._ws = None
70+
except Exception:
71+
await self.close()
72+
asyncio.ensure_future(self._ensure_ws())
6773

6874
async def _ensure_polygon(self):
6975
if self.polygon is not None:
@@ -79,10 +85,22 @@ async def _ensure_polygon(self):
7985
async def _ensure_ws(self):
8086
if self._ws is not None:
8187
return
82-
self._ws = await self._connect()
88+
89+
while self._retries <= self._retry:
90+
try:
91+
await self._connect()
92+
if self._streams:
93+
await self.subscribe(self._streams)
94+
break
95+
except Exception:
96+
self._ws = None
97+
self._retries += 1
98+
await asyncio.sleep(self._retry_wait * self._retry)
99+
else:
100+
raise ConnectionError("Max Retries Exceeded")
83101

84102
async def subscribe(self, channels):
85-
'''Start subscribing channels.
103+
'''Start subscribing to channels.
86104
If the necessary connection isn't open yet, it opens now.
87105
'''
88106
ws_channels = []
@@ -94,6 +112,7 @@ async def subscribe(self, channels):
94112
ws_channels.append(c)
95113

96114
if len(ws_channels) > 0:
115+
self._streams |= set(ws_channels)
97116
await self._ensure_ws()
98117
await self._ws.send(json.dumps({
99118
'action': 'listen',
@@ -129,7 +148,7 @@ async def unsubscribe(self, channels):
129148
await self.polygon.unsubscribe(polygon_channels)
130149

131150
def run(self, initial_channels=[]):
132-
'''Run forever and block until exception is rasised.
151+
'''Run forever and block until exception is raised.
133152
initial_channels is the channels to start with.
134153
'''
135154
loop = self.loop
@@ -146,8 +165,10 @@ async def close(self):
146165
'''Close any of open connections'''
147166
if self._ws is not None:
148167
await self._ws.close()
168+
self._ws = None
149169
if self.polygon is not None:
150170
await self.polygon.close()
171+
self.polygon = None
151172

152173
def _cast(self, channel, msg):
153174
if channel == 'account_updates':

tests/test_stream2.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ class TestException(Exception):
7777
async def on_raise(conn, stream, msg):
7878
raise TestException()
7979

80-
with pytest.raises(TestException):
81-
_run(conn._consume_msg())
82-
assert ws.close.mock.called
83-
8480
# _ensure_polygon
8581
conn = StreamConn('key-id', 'secret-key')
8682
with mock.patch('alpaca_trade_api.stream2.polygon') as polygon:
@@ -94,7 +90,6 @@ async def on_raise(conn, stream, msg):
9490
conn._connect = AsyncMock()
9591
_run(conn._ensure_ws())
9692
assert conn._connect.mock.called
97-
assert conn._ws is not None
9893

9994
# subscribe
10095
conn = StreamConn('key-id', 'secret-key')
@@ -116,8 +111,8 @@ async def on_raise(conn, stream, msg):
116111
conn.polygon = mock.Mock()
117112
conn.polygon.close = AsyncMock()
118113
_run(conn.close())
119-
assert conn._ws.close.mock.called
120-
assert conn.polygon.close.mock.called
114+
assert conn._ws is None
115+
assert conn.polygon is None
121116

122117
# _cast
123118
conn = StreamConn('key-id', 'secret-key')

0 commit comments

Comments
 (0)