1- from typing import Dict , Any , List
1+ from typing import Dict , Any , List , Type , Protocol , cast
22from .common import *
33from .models import *
44import logging
55
6+
7+ # Protocol to define classes with from_dict method
8+ class FromDictProtocol (Protocol ):
9+ @classmethod
10+ def from_dict (cls , data : Dict [str , Any ]) -> "FromDictProtocol" :
11+ pass
12+
13+
614# Define the mapping of market and event type to model class
7- MARKET_EVENT_MAP = {
15+ MARKET_EVENT_MAP : Dict [ Market , Dict [ str , Type [ FromDictProtocol ]]] = {
816 Market .Stocks : {
917 "A" : EquityAgg ,
1018 "AM" : EquityAgg ,
5058}
5159
5260
53- def parse_single (data : Dict [str , Any ], market : Market , logger : logging .Logger ) -> Any :
61+ def parse_single (
62+ data : Dict [str , Any ], logger : logging .Logger , market : Market
63+ ) -> Optional [WebSocketMessage ]:
5464 event_type = data ["ev" ]
5565 # Look up the model class based on market and event type
5666 model_class : Optional [Type [FromDictProtocol ]] = MARKET_EVENT_MAP .get (
5767 market , {}
5868 ).get (event_type )
5969 if model_class :
60- return model_class .from_dict (data )
70+ parsed = model_class .from_dict (data )
71+ return cast (
72+ WebSocketMessage , parsed
73+ ) # Ensure the return type is WebSocketMessage
6174 else :
6275 # Log a warning for unrecognized event types, unless it's a status message
6376 if event_type != "status" :
@@ -70,7 +83,7 @@ def parse(
7083) -> List [WebSocketMessage ]:
7184 res = []
7285 for m in msg :
73- parsed = parse_single (m , market )
86+ parsed = parse_single (m , logger , market )
7487 if parsed is None :
7588 if m ["ev" ] != "status" :
7689 logger .warning ("could not parse message %s" , m )
0 commit comments