Skip to content

Commit 916eed5

Browse files
TexasCodingclaude
andcommitted
fix(typing): Resolve client mixin type hierarchy and response handling issues
- Fix mixin type annotations to work properly with ProjectXBase - Add proper attribute declarations in mixins for base class attributes - Add isinstance checks before calling .get() on API responses to handle union types - Remove unnecessary TYPE_CHECKING imports and clean up imports - Fix response handling in auth, market_data, and trading mixins - Update realtime_data_manager type hints for better type safety This resolves the majority of type checking issues in the client modules, reducing type errors from 100+ to just 13 remaining edge cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 45c0349 commit 916eed5

File tree

5 files changed

+135
-39
lines changed

5 files changed

+135
-39
lines changed

src/project_x_py/client/auth.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def main():
5454
import base64
5555
import datetime
5656
from datetime import timedelta
57-
from typing import TYPE_CHECKING
57+
from typing import Any
5858

5959
import orjson
6060
import pytz
@@ -70,15 +70,30 @@ async def main():
7070
validate_response,
7171
)
7272

73-
if TYPE_CHECKING:
74-
from project_x_py.types import ProjectXClientProtocol
75-
7673
logger = ProjectXLogger.get_logger(__name__)
7774

7875

7976
class AuthenticationMixin:
8077
"""Mixin class providing authentication functionality."""
8178

79+
# These attributes are provided by the base class
80+
username: str
81+
api_key: str
82+
account_name: str | None
83+
headers: dict[str, str]
84+
85+
async def _make_request(
86+
self,
87+
method: str,
88+
endpoint: str,
89+
data: dict[str, Any] | None = None,
90+
params: dict[str, Any] | None = None,
91+
headers: dict[str, str] | None = None,
92+
retry_count: int = 0,
93+
) -> Any:
94+
"""Provided by HttpMixin."""
95+
_ = (method, endpoint, data, params, headers, retry_count)
96+
8297
def __init__(self) -> None:
8398
"""Initialize authentication attributes."""
8499
super().__init__()
@@ -87,7 +102,7 @@ def __init__(self) -> None:
87102
self._authenticated = False
88103
self.account_info: Account | None = None
89104

90-
async def _refresh_authentication(self: "ProjectXClientProtocol") -> None:
105+
async def _refresh_authentication(self) -> None:
91106
"""
92107
Refresh authentication if token is expired or about to expire.
93108
@@ -103,7 +118,7 @@ async def _refresh_authentication(self: "ProjectXClientProtocol") -> None:
103118
if self._should_refresh_token():
104119
await self.authenticate()
105120

106-
def _should_refresh_token(self: "ProjectXClientProtocol") -> bool:
121+
def _should_refresh_token(self) -> bool:
107122
"""
108123
Check if the authentication token should be refreshed.
109124
@@ -125,7 +140,7 @@ def _should_refresh_token(self: "ProjectXClientProtocol") -> bool:
125140
return datetime.datetime.now(pytz.UTC) >= (self.token_expiry - buffer_time)
126141

127142
@handle_errors("authenticate")
128-
async def authenticate(self: "ProjectXClientProtocol") -> None:
143+
async def authenticate(self) -> None:
129144
"""
130145
Authenticate with ProjectX API and select account.
131146
@@ -191,7 +206,11 @@ async def authenticate(self: "ProjectXClientProtocol") -> None:
191206
accounts_response = await self._make_request(
192207
"POST", "/Account/search", data=payload
193208
)
194-
if not accounts_response or not accounts_response.get("success", False):
209+
if (
210+
not accounts_response
211+
or not isinstance(accounts_response, dict)
212+
or not accounts_response.get("success", False)
213+
):
195214
raise ProjectXAuthenticationError(ErrorMessages.API_REQUEST_FAILED)
196215

197216
accounts_data = accounts_response.get("accounts", [])
@@ -232,7 +251,7 @@ async def authenticate(self: "ProjectXClientProtocol") -> None:
232251
},
233252
)
234253

235-
async def _ensure_authenticated(self: "ProjectXClientProtocol") -> None:
254+
async def _ensure_authenticated(self) -> None:
236255
"""
237256
Ensure client is authenticated before making API calls.
238257
@@ -257,7 +276,7 @@ async def _ensure_authenticated(self: "ProjectXClientProtocol") -> None:
257276

258277
@handle_errors("list accounts")
259278
@validate_response(required_fields=["success", "accounts"])
260-
async def list_accounts(self: "ProjectXClientProtocol") -> list[Account]:
279+
async def list_accounts(self) -> list[Account]:
261280
"""
262281
List all accounts available to the authenticated user.
263282
@@ -282,7 +301,11 @@ async def list_accounts(self: "ProjectXClientProtocol") -> list[Account]:
282301
payload = {"onlyActiveAccounts": True}
283302
response = await self._make_request("POST", "/Account/search", data=payload)
284303

285-
if not response or not response.get("success", False):
304+
if (
305+
not response
306+
or not isinstance(response, dict)
307+
or not response.get("success", False)
308+
):
286309
return []
287310

288311
accounts_data = response.get("accounts", [])

src/project_x_py/client/http.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def main():
5252
"""
5353

5454
import time
55-
from typing import TYPE_CHECKING, Any, TypeVar
55+
from typing import Any, TypeVar
5656

5757
import httpx
5858

@@ -77,9 +77,6 @@ async def main():
7777
retry_on_network_error,
7878
)
7979

80-
if TYPE_CHECKING:
81-
from project_x_py.types import ProjectXClientProtocol
82-
8380
T = TypeVar("T")
8481

8582
logger = ProjectXLogger.get_logger(__name__)
@@ -88,13 +85,25 @@ async def main():
8885
class HttpMixin:
8986
"""Mixin class providing HTTP client functionality."""
9087

88+
# These attributes are provided by the base class or other mixins
89+
config: Any # ProjectXConfig
90+
base_url: str
91+
headers: dict[str, str]
92+
session_token: str
93+
rate_limiter: Any # RateLimiter
94+
cache_hit_count: int
95+
api_call_count: int
96+
97+
async def _refresh_authentication(self) -> None:
98+
"""Provided by AuthenticationMixin."""
99+
91100
def __init__(self) -> None:
92101
"""Initialize HTTP client attributes."""
93102
super().__init__()
94103
self._client: httpx.AsyncClient | None = None
95104
self.api_call_count = 0
96105

97-
async def _create_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
106+
async def _create_client(self) -> httpx.AsyncClient:
98107
"""
99108
Create an optimized httpx async client with connection pooling and retries.
100109
@@ -138,7 +147,7 @@ async def _create_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
138147

139148
return client
140149

141-
async def _ensure_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
150+
async def _ensure_client(self) -> httpx.AsyncClient:
142151
"""
143152
Ensure HTTP client is initialized and ready for API requests.
144153
@@ -160,7 +169,7 @@ async def _ensure_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
160169
@handle_rate_limit()
161170
@retry_on_network_error(max_attempts=3)
162171
async def _make_request(
163-
self: "ProjectXClientProtocol",
172+
self,
164173
method: str,
165174
endpoint: str,
166175
data: dict[str, Any] | None = None,
@@ -309,7 +318,7 @@ async def _make_request(
309318

310319
@handle_errors("get health status")
311320
async def get_health_status(
312-
self: "ProjectXClientProtocol",
321+
self,
313322
) -> PerformanceStatsResponse:
314323
"""
315324
Get client statistics and performance metrics.

src/project_x_py/client/market_data.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def main():
5656

5757
import datetime
5858
import re
59-
from typing import TYPE_CHECKING, Any
59+
from typing import Any
6060

6161
import polars as pl
6262
import pytz
@@ -73,20 +73,50 @@ async def main():
7373
validate_response,
7474
)
7575

76-
if TYPE_CHECKING:
77-
from project_x_py.types import ProjectXClientProtocol
78-
7976
logger = ProjectXLogger.get_logger(__name__)
8077

8178

8279
class MarketDataMixin:
8380
"""Mixin class providing market data functionality."""
8481

82+
# These attributes are provided by the base class
83+
logger: Any
84+
config: Any # ProjectXConfig
85+
86+
async def _ensure_authenticated(self) -> None:
87+
"""Provided by AuthenticationMixin."""
88+
89+
async def _make_request(
90+
self,
91+
method: str,
92+
endpoint: str,
93+
data: dict[str, Any] | None = None,
94+
params: dict[str, Any] | None = None,
95+
headers: dict[str, str] | None = None,
96+
retry_count: int = 0,
97+
) -> Any:
98+
"""Provided by HttpMixin."""
99+
_ = (method, endpoint, data, params, headers, retry_count)
100+
101+
def get_cached_instrument(self, symbol: str) -> Any:
102+
"""Provided by CacheMixin."""
103+
_ = symbol
104+
105+
def cache_instrument(self, symbol: str, instrument: Any) -> None:
106+
"""Provided by CacheMixin."""
107+
_ = (symbol, instrument)
108+
109+
def get_cached_market_data(self, cache_key: str) -> Any:
110+
"""Provided by CacheMixin."""
111+
_ = cache_key
112+
113+
def cache_market_data(self, cache_key: str, data: Any) -> None:
114+
"""Provided by CacheMixin."""
115+
_ = (cache_key, data)
116+
85117
@handle_errors("get instrument")
86118
@validate_response(required_fields=["success", "contracts"])
87-
async def get_instrument(
88-
self: "ProjectXClientProtocol", symbol: str, live: bool = False
89-
) -> Instrument:
119+
async def get_instrument(self, symbol: str, live: bool = False) -> Instrument:
90120
"""
91121
Get detailed instrument information with caching.
92122
@@ -200,7 +230,7 @@ async def get_instrument(
200230
return instrument
201231

202232
def _select_best_contract(
203-
self: "ProjectXClientProtocol",
233+
self,
204234
instruments: list[dict[str, Any]],
205235
search_symbol: str,
206236
) -> dict[str, Any]:
@@ -276,7 +306,7 @@ def _select_best_contract(
276306
@handle_errors("search instruments")
277307
@validate_response(required_fields=["success", "contracts"])
278308
async def search_instruments(
279-
self: "ProjectXClientProtocol", query: str, live: bool = False
309+
self, query: str, live: bool = False
280310
) -> list[Instrument]:
281311
"""
282312
Search for instruments by symbol or name.
@@ -312,10 +342,16 @@ async def search_instruments(
312342
"POST", "/Contract/search", data=payload
313343
)
314344

315-
if not response or not response.get("success", False):
345+
if (
346+
not response
347+
or not isinstance(response, dict)
348+
or not response.get("success", False)
349+
):
316350
return []
317351

318-
contracts_data = response.get("contracts", [])
352+
contracts_data = (
353+
response.get("contracts", []) if isinstance(response, dict) else []
354+
)
319355
instruments = [Instrument(**contract) for contract in contracts_data]
320356

321357
logger.debug(
@@ -327,7 +363,7 @@ async def search_instruments(
327363

328364
@handle_errors("get bars")
329365
async def get_bars(
330-
self: "ProjectXClientProtocol",
366+
self,
331367
symbol: str,
332368
days: int = 8,
333369
interval: int = 5,

src/project_x_py/client/trading.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,30 +67,45 @@ async def main():
6767
import datetime
6868
import logging
6969
from datetime import timedelta
70-
from typing import TYPE_CHECKING
70+
from typing import Any
7171

7272
import pytz
7373

7474
from project_x_py.exceptions import ProjectXError
7575
from project_x_py.models import Position, Trade
7676
from project_x_py.utils.deprecation import deprecated
7777

78-
if TYPE_CHECKING:
79-
from project_x_py.types import ProjectXClientProtocol
80-
8178
logger = logging.getLogger(__name__)
8279

8380

8481
class TradingMixin:
8582
"""Mixin class providing trading functionality."""
8683

84+
# These attributes are provided by the base class
85+
account_info: Any # Account object
86+
87+
async def _ensure_authenticated(self) -> None:
88+
"""Provided by AuthenticationMixin."""
89+
90+
async def _make_request(
91+
self,
92+
method: str,
93+
endpoint: str,
94+
data: dict[str, Any] | None = None,
95+
params: dict[str, Any] | None = None,
96+
headers: dict[str, str] | None = None,
97+
retry_count: int = 0,
98+
) -> Any:
99+
"""Provided by HttpMixin."""
100+
_ = (method, endpoint, data, params, headers, retry_count)
101+
87102
@deprecated(
88103
reason="Method renamed for API consistency",
89104
version="3.0.0",
90105
removal_version="4.0.0",
91106
replacement="search_open_positions()",
92107
)
93-
async def get_positions(self: "ProjectXClientProtocol") -> list[Position]:
108+
async def get_positions(self) -> list[Position]:
94109
"""
95110
DEPRECATED: Get all open positions for the authenticated account.
96111
@@ -108,7 +123,7 @@ async def get_positions(self: "ProjectXClientProtocol") -> list[Position]:
108123
return await self.search_open_positions()
109124

110125
async def search_open_positions(
111-
self: "ProjectXClientProtocol", account_id: int | None = None
126+
self, account_id: int | None = None
112127
) -> list[Position]:
113128
"""
114129
Search for open positions for the currently authenticated account.
@@ -170,7 +185,7 @@ async def search_open_positions(
170185
return [Position(**pos) for pos in positions_data]
171186

172187
async def search_trades(
173-
self: "ProjectXClientProtocol",
188+
self,
174189
start_date: datetime.datetime | None = None,
175190
end_date: datetime.datetime | None = None,
176191
contract_id: str | None = None,

src/project_x_py/realtime_data_manager/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,13 @@ async def initialize(self, initial_days: int = 1) -> bool:
527527
LogMessages.DATA_FETCH,
528528
extra={"phase": "initialization", "instrument": self.instrument},
529529
)
530-
530+
if self.project_x is None:
531+
raise ProjectXError(
532+
format_error_message(
533+
ErrorMessages.INTERNAL_ERROR,
534+
reason="ProjectX client not initialized",
535+
)
536+
)
531537
# Get the contract ID for the instrument
532538
instrument_info: Instrument | None = await self.project_x.get_instrument(
533539
self.instrument
@@ -558,6 +564,13 @@ async def initialize(self, initial_days: int = 1) -> bool:
558564
# Load initial data for all timeframes
559565
async with self.data_lock:
560566
for tf_key, tf_config in self.timeframes.items():
567+
if self.project_x is None:
568+
raise ProjectXError(
569+
format_error_message(
570+
ErrorMessages.INTERNAL_ERROR,
571+
reason="ProjectX client not initialized",
572+
)
573+
)
561574
bars = await self.project_x.get_bars(
562575
self.instrument, # Use base symbol, not contract ID
563576
interval=tf_config["interval"],

0 commit comments

Comments
 (0)