1414
1515from .common import get_base_url , get_data_stream_url , get_credentials , URL
1616from .entity import Entity
17- from .entity_v2 import quote_mapping_v2 , trade_mapping_v2 , bar_mapping_v2 , \
18- status_mapping_v2 , luld_mapping_v2 , Trade , Quote , Bar , StatusV2 , LULDV2
17+ from .entity_v2 import (
18+ quote_mapping_v2 ,
19+ trade_mapping_v2 ,
20+ bar_mapping_v2 ,
21+ status_mapping_v2 ,
22+ luld_mapping_v2 ,
23+ cancel_error_mapping_v2 ,
24+ correction_mapping_v2 ,
25+ Trade ,
26+ Quote ,
27+ Bar ,
28+ StatusV2 ,
29+ LULDV2 ,
30+ CancelErrorV2 ,
31+ CorrectionV2 ,
32+ )
1933
2034log = logging .getLogger (__name__ )
2135
@@ -167,7 +181,10 @@ def _subscribe(self, handler, symbols, handlers):
167181 asyncio .get_event_loop ().run_until_complete (self ._subscribe_all ())
168182
169183 async def _subscribe_all (self ):
170- if any (self ._handlers .values ()):
184+ if any (
185+ v for k , v in self ._handlers .items ()
186+ if k not in ("cancelErrors" , "corrections" )
187+ ):
171188 msg = {
172189 k : tuple (v .keys ())
173190 for k , v in self ._handlers .items ()
@@ -193,7 +210,10 @@ async def _unsubscribe(self,
193210
194211 async def _run_forever (self ):
195212 # do not start the websocket connection until we subscribe to something
196- while not any (self ._handlers .values ()):
213+ while not any (
214+ v for k , v in self ._handlers .items ()
215+ if k not in ("cancelErrors" , "corrections" )
216+ ):
197217 if not self ._stop_stream_queue .empty ():
198218 # the ws was signaled to stop before starting the loop so
199219 # we break
@@ -283,6 +303,8 @@ def __init__(self,
283303 )
284304 self ._handlers ['statuses' ] = {}
285305 self ._handlers ['lulds' ] = {}
306+ self ._handlers ['cancelErrors' ] = {}
307+ self ._handlers ['corrections' ] = {}
286308 self ._name = 'stock data'
287309
288310 def _cast (self , msg_type , msg ):
@@ -298,6 +320,16 @@ def _cast(self, msg_type, msg):
298320 luld_mapping_v2 [k ]: v
299321 for k , v in msg .items () if k in luld_mapping_v2
300322 })
323+ elif msg_type == 'x' :
324+ result = CancelErrorV2 ({
325+ cancel_error_mapping_v2 [k ]: v
326+ for k , v in msg .items () if k in cancel_error_mapping_v2
327+ })
328+ elif msg_type == 'c' :
329+ result = CorrectionV2 ({
330+ correction_mapping_v2 [k ]: v
331+ for k , v in msg .items () if k in correction_mapping_v2
332+ })
301333 return result
302334
303335 async def _dispatch (self , msg ):
@@ -313,6 +345,16 @@ async def _dispatch(self, msg):
313345 symbol , self ._handlers ['lulds' ].get ('*' , None ))
314346 if handler :
315347 await handler (self ._cast (msg_type , msg ))
348+ elif msg_type == 'x' :
349+ handler = self ._handlers ['cancelErrors' ].get (
350+ symbol , self ._handlers ['cancelErrors' ].get ('*' , None ))
351+ if handler :
352+ await handler (self ._cast (msg_type , msg ))
353+ elif msg_type == 'c' :
354+ handler = self ._handlers ['corrections' ].get (
355+ symbol , self ._handlers ['corrections' ].get ('*' , None ))
356+ if handler :
357+ await handler (self ._cast (msg_type , msg ))
316358 else :
317359 await super ()._dispatch (msg )
318360
@@ -355,6 +397,16 @@ def unsubscribe_lulds(self, *symbols):
355397 for symbol in symbols :
356398 del self ._handlers ['lulds' ][symbol ]
357399
400+ def register_handler (self , msg_type , handler , * symbols ):
401+ if handler is not None :
402+ _ensure_coroutine (handler )
403+ for symbol in symbols :
404+ self ._handlers [msg_type ][symbol ] = handler
405+
406+ def unregister_handler (self , msg_type , * symbols ):
407+ for symbol in symbols :
408+ del self ._handlers [msg_type ][symbol ]
409+
358410
359411class CryptoDataStream (_DataStream ):
360412 def __init__ (self ,
@@ -531,8 +583,20 @@ def __init__(self,
531583 def subscribe_trade_updates (self , handler ):
532584 self ._trading_ws .subscribe_trade_updates (handler )
533585
534- def subscribe_trades (self , handler , * symbols ):
586+ def subscribe_trades (
587+ self ,
588+ handler ,
589+ * symbols ,
590+ handler_cancel_errors = None ,
591+ handler_corrections = None
592+ ):
535593 self ._data_ws .subscribe_trades (handler , * symbols )
594+ self ._data_ws .register_handler ("cancelErrors" ,
595+ handler_cancel_errors ,
596+ * symbols )
597+ self ._data_ws .register_handler ("corrections" ,
598+ handler_corrections ,
599+ * symbols )
536600
537601 def subscribe_quotes (self , handler , * symbols ):
538602 self ._data_ws .subscribe_quotes (handler , * symbols )
@@ -607,6 +671,20 @@ def decorator(func):
607671
608672 return decorator
609673
674+ def on_cancel_error (self , * symbols ):
675+ def decorator (func ):
676+ self ._data_ws .register_handler ("cancelErrors" , func , * symbols )
677+ return func
678+
679+ return decorator
680+
681+ def on_corrections (self , * symbols ):
682+ def decorator (func ):
683+ self ._data_ws .register_handler ("corrections" , func , * symbols )
684+ return func
685+
686+ return decorator
687+
610688 def on_crypto_trade (self , * symbols ):
611689 def decorator (func ):
612690 self .subscribe_crypto_trades (func , * symbols )
@@ -637,6 +715,8 @@ def decorator(func):
637715
638716 def unsubscribe_trades (self , * symbols ):
639717 self ._data_ws .unsubscribe_trades (* symbols )
718+ self ._data_ws .unregister_handler ("cancelErrors" , * symbols )
719+ self ._data_ws .unregister_handler ("corrections" , * symbols )
640720
641721 def unsubscribe_quotes (self , * symbols ):
642722 self ._data_ws .unsubscribe_quotes (* symbols )
0 commit comments