Skip to content

Commit 7d9872b

Browse files
authored
Handle trade error cancels and corrections as auto subscribed messages with trades (#536)
1 parent c0fbfd9 commit 7d9872b

File tree

2 files changed

+125
-5
lines changed

2 files changed

+125
-5
lines changed

alpaca_trade_api/entity_v2.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,32 @@
6161
"z": "tape"
6262
}
6363

64+
cancel_error_mapping_v2 = {
65+
"S": "symbol",
66+
"i": "id",
67+
"x": "exchange",
68+
"p": "price",
69+
"s": "size",
70+
"a": "cancel_error_action",
71+
"z": "tape",
72+
"t": "timestamp",
73+
}
74+
75+
correction_mapping_v2 = {
76+
"S": "symbol",
77+
"x": "exchange",
78+
"oi": "original_id",
79+
"op": "original_price",
80+
"os": "original_size",
81+
"oc": "original_conditions",
82+
"ci": "corrected_id",
83+
"cp": "corrected_price",
84+
"cs": "corrected_size",
85+
"cc": "corrected_conditions",
86+
"z": "tape",
87+
"t": "timestamp",
88+
}
89+
6490

6591
class EntityListType(Enum):
6692
Trade = Trade, trade_mapping_v2
@@ -152,6 +178,20 @@ def __init__(self, raw):
152178
super().__init__(luld_mapping_v2, raw)
153179

154180

181+
class CancelErrorV2(Remapped, _NanoTimestamped, Entity):
182+
_tskeys = ('t',)
183+
184+
def __init__(self, raw):
185+
super().__init__(cancel_error_mapping_v2, raw)
186+
187+
188+
class CorrectionV2(Remapped, _NanoTimestamped, Entity):
189+
_tskeys = ('t',)
190+
191+
def __init__(self, raw):
192+
super().__init__(correction_mapping_v2, raw)
193+
194+
155195
class SnapshotV2:
156196
def __init__(self, raw):
157197
self.latest_trade = _convert_or_none(TradeV2, raw.get('latestTrade'))

alpaca_trade_api/stream.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,22 @@
1414

1515
from .common import get_base_url, get_data_stream_url, get_credentials, URL
1616
from .entity import Entity
17-
from .entity_v2 import quote_mapping_v2, trade_mapping_v2, bar_mapping_v2, \
18-
status_mapping_v2, luld_mapping_v2, Trade, Quote, Bar, StatusV2, LULDV2
17+
from .entity_v2 import (
18+
quote_mapping_v2,
19+
trade_mapping_v2,
20+
bar_mapping_v2,
21+
status_mapping_v2,
22+
luld_mapping_v2,
23+
cancel_error_mapping_v2,
24+
correction_mapping_v2,
25+
Trade,
26+
Quote,
27+
Bar,
28+
StatusV2,
29+
LULDV2,
30+
CancelErrorV2,
31+
CorrectionV2,
32+
)
1933

2034
log = logging.getLogger(__name__)
2135

@@ -167,7 +181,10 @@ def _subscribe(self, handler, symbols, handlers):
167181
asyncio.get_event_loop().run_until_complete(self._subscribe_all())
168182

169183
async def _subscribe_all(self):
170-
if any(self._handlers.values()):
184+
if any(
185+
v for k, v in self._handlers.items()
186+
if k not in ("cancelErrors", "corrections")
187+
):
171188
msg = {
172189
k: tuple(v.keys())
173190
for k, v in self._handlers.items()
@@ -193,7 +210,10 @@ async def _unsubscribe(self,
193210

194211
async def _run_forever(self):
195212
# do not start the websocket connection until we subscribe to something
196-
while not any(self._handlers.values()):
213+
while not any(
214+
v for k, v in self._handlers.items()
215+
if k not in ("cancelErrors", "corrections")
216+
):
197217
if not self._stop_stream_queue.empty():
198218
# the ws was signaled to stop before starting the loop so
199219
# we break
@@ -283,6 +303,8 @@ def __init__(self,
283303
)
284304
self._handlers['statuses'] = {}
285305
self._handlers['lulds'] = {}
306+
self._handlers['cancelErrors'] = {}
307+
self._handlers['corrections'] = {}
286308
self._name = 'stock data'
287309

288310
def _cast(self, msg_type, msg):
@@ -298,6 +320,16 @@ def _cast(self, msg_type, msg):
298320
luld_mapping_v2[k]: v
299321
for k, v in msg.items() if k in luld_mapping_v2
300322
})
323+
elif msg_type == 'x':
324+
result = CancelErrorV2({
325+
cancel_error_mapping_v2[k]: v
326+
for k, v in msg.items() if k in cancel_error_mapping_v2
327+
})
328+
elif msg_type == 'c':
329+
result = CorrectionV2({
330+
correction_mapping_v2[k]: v
331+
for k, v in msg.items() if k in correction_mapping_v2
332+
})
301333
return result
302334

303335
async def _dispatch(self, msg):
@@ -313,6 +345,16 @@ async def _dispatch(self, msg):
313345
symbol, self._handlers['lulds'].get('*', None))
314346
if handler:
315347
await handler(self._cast(msg_type, msg))
348+
elif msg_type == 'x':
349+
handler = self._handlers['cancelErrors'].get(
350+
symbol, self._handlers['cancelErrors'].get('*', None))
351+
if handler:
352+
await handler(self._cast(msg_type, msg))
353+
elif msg_type == 'c':
354+
handler = self._handlers['corrections'].get(
355+
symbol, self._handlers['corrections'].get('*', None))
356+
if handler:
357+
await handler(self._cast(msg_type, msg))
316358
else:
317359
await super()._dispatch(msg)
318360

@@ -355,6 +397,16 @@ def unsubscribe_lulds(self, *symbols):
355397
for symbol in symbols:
356398
del self._handlers['lulds'][symbol]
357399

400+
def register_handler(self, msg_type, handler, *symbols):
401+
if handler is not None:
402+
_ensure_coroutine(handler)
403+
for symbol in symbols:
404+
self._handlers[msg_type][symbol] = handler
405+
406+
def unregister_handler(self, msg_type, *symbols):
407+
for symbol in symbols:
408+
del self._handlers[msg_type][symbol]
409+
358410

359411
class CryptoDataStream(_DataStream):
360412
def __init__(self,
@@ -531,8 +583,20 @@ def __init__(self,
531583
def subscribe_trade_updates(self, handler):
532584
self._trading_ws.subscribe_trade_updates(handler)
533585

534-
def subscribe_trades(self, handler, *symbols):
586+
def subscribe_trades(
587+
self,
588+
handler,
589+
*symbols,
590+
handler_cancel_errors=None,
591+
handler_corrections=None
592+
):
535593
self._data_ws.subscribe_trades(handler, *symbols)
594+
self._data_ws.register_handler("cancelErrors",
595+
handler_cancel_errors,
596+
*symbols)
597+
self._data_ws.register_handler("corrections",
598+
handler_corrections,
599+
*symbols)
536600

537601
def subscribe_quotes(self, handler, *symbols):
538602
self._data_ws.subscribe_quotes(handler, *symbols)
@@ -607,6 +671,20 @@ def decorator(func):
607671

608672
return decorator
609673

674+
def on_cancel_error(self, *symbols):
675+
def decorator(func):
676+
self._data_ws.register_handler("cancelErrors", func, *symbols)
677+
return func
678+
679+
return decorator
680+
681+
def on_corrections(self, *symbols):
682+
def decorator(func):
683+
self._data_ws.register_handler("corrections", func, *symbols)
684+
return func
685+
686+
return decorator
687+
610688
def on_crypto_trade(self, *symbols):
611689
def decorator(func):
612690
self.subscribe_crypto_trades(func, *symbols)
@@ -637,6 +715,8 @@ def decorator(func):
637715

638716
def unsubscribe_trades(self, *symbols):
639717
self._data_ws.unsubscribe_trades(*symbols)
718+
self._data_ws.unregister_handler("cancelErrors", *symbols)
719+
self._data_ws.unregister_handler("corrections", *symbols)
640720

641721
def unsubscribe_quotes(self, *symbols):
642722
self._data_ws.unsubscribe_quotes(*symbols)

0 commit comments

Comments
 (0)