Skip to content

Commit 413fef3

Browse files
committed
Fix lint errors
1 parent ad005b9 commit 413fef3

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

polygon/websocket/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,19 @@ def __init__(
4949
)
5050
self.api_key = api_key
5151
self.feed = feed
52-
self.market = market
52+
if isinstance(market, str):
53+
self.market = Market(market) # converts str input to enum
54+
else:
55+
self.market = market
56+
57+
self.market_value = self.market.value
5358
self.raw = raw
5459
if verbose:
5560
logger.setLevel(logging.DEBUG)
5661
self.websocket_cfg = kwargs
5762
if isinstance(feed, Enum):
5863
feed = feed.value
59-
if isinstance(market, Enum):
60-
market = market.value
61-
self.url = f"ws{'s' if secure else ''}://{self.feed.value}/{self.market.value}"
64+
self.url = f"ws{'s' if secure else ''}://{feed}/{self.market_value}"
6265
self.subscribed = False
6366
self.subs: Set[str] = set()
6467
self.max_reconnects = max_reconnects

polygon/websocket/models/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from typing import Dict, Any, List
1+
from typing import Dict, Any, List, Type, Protocol, cast
22
from .common import *
33
from .models import *
44
import logging
55

6+
7+
# Protocol to define classes with from_dict method
8+
class FromDictProtocol(Protocol):
9+
@classmethod
10+
def from_dict(cls, data: Dict[str, Any]) -> "FromDictProtocol":
11+
pass
12+
13+
614
# Define the mapping of market and event type to model class
7-
MARKET_EVENT_MAP = {
15+
MARKET_EVENT_MAP: Dict[Market, Dict[str, Type[FromDictProtocol]]] = {
816
Market.Stocks: {
917
"A": EquityAgg,
1018
"AM": EquityAgg,
@@ -50,14 +58,19 @@
5058
}
5159

5260

53-
def parse_single(data: Dict[str, Any], market: Market, logger: logging.Logger) -> Any:
61+
def parse_single(
62+
data: Dict[str, Any], logger: logging.Logger, market: Market
63+
) -> Optional[WebSocketMessage]:
5464
event_type = data["ev"]
5565
# Look up the model class based on market and event type
5666
model_class: Optional[Type[FromDictProtocol]] = MARKET_EVENT_MAP.get(
5767
market, {}
5868
).get(event_type)
5969
if model_class:
60-
return model_class.from_dict(data)
70+
parsed = model_class.from_dict(data)
71+
return cast(
72+
WebSocketMessage, parsed
73+
) # Ensure the return type is WebSocketMessage
6174
else:
6275
# Log a warning for unrecognized event types, unless it's a status message
6376
if event_type != "status":
@@ -70,7 +83,7 @@ def parse(
7083
) -> List[WebSocketMessage]:
7184
res = []
7285
for m in msg:
73-
parsed = parse_single(m, market)
86+
parsed = parse_single(m, logger, market)
7487
if parsed is None:
7588
if m["ev"] != "status":
7689
logger.warning("could not parse message %s", m)

0 commit comments

Comments
 (0)