Skip to content

Commit baab82a

Browse files
committed
fix(types): update factory functions and managers to use ProjectXBase instead of ProjectX
- Update create_trading_suite and all factory functions to accept ProjectXBase - Fix OrderManager, PositionManager, RealtimeDataManager, and OrderBook constructors - Update all related protocols and type annotations - Fix MultiTimeframeStrategy example to use ProjectXBase type - Resolve linting errors for type compatibility across the entire codebase - All tests pass successfully after type system updates
1 parent 19b1e27 commit baab82a

File tree

24 files changed

+233
-84
lines changed

24 files changed

+233
-84
lines changed

examples/06_multi_timeframe_strategy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
import logging
2626
import signal
2727
from datetime import datetime
28+
from typing import Any
2829

2930
from project_x_py import (
3031
ProjectX,
32+
ProjectXBase,
3133
create_trading_suite,
3234
)
3335
from project_x_py.indicators import RSI, SMA
@@ -48,8 +50,8 @@ class MultiTimeframeStrategy:
4850

4951
def __init__(
5052
self,
51-
client: ProjectX,
52-
trading_suite: dict,
53+
client: ProjectXBase,
54+
trading_suite: dict[str, Any],
5355
symbol: str = "MNQ",
5456
max_position_size: int = 2,
5557
risk_percentage: float = 0.02,

src/project_x_py/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from typing import Any
2525

26+
from project_x_py.client.base import ProjectXBase
27+
2628
__version__ = "2.0.4"
2729
__author__ = "TexasCoding"
2830

@@ -174,7 +176,7 @@
174176
# Factory functions - Updated to be async-only
175177
async def create_trading_suite(
176178
instrument: str,
177-
project_x: ProjectX,
179+
project_x: ProjectXBase,
178180
jwt_token: str | None = None,
179181
account_id: str | None = None,
180182
timeframes: list[str] | None = None,
@@ -284,7 +286,7 @@ async def create_trading_suite(
284286

285287

286288
def create_order_manager(
287-
project_x: ProjectX,
289+
project_x: ProjectXBase,
288290
realtime_client: ProjectXRealtimeClient | None = None,
289291
) -> OrderManager:
290292
"""
@@ -314,7 +316,7 @@ def create_order_manager(
314316

315317

316318
def create_position_manager(
317-
project_x: ProjectX,
319+
project_x: ProjectXBase,
318320
realtime_client: ProjectXRealtimeClient | None = None,
319321
order_manager: OrderManager | None = None,
320322
) -> PositionManager:
@@ -380,7 +382,7 @@ def create_realtime_client(
380382

381383
def create_data_manager(
382384
instrument: str,
383-
project_x: ProjectX,
385+
project_x: ProjectXBase,
384386
realtime_client: ProjectXRealtimeClient,
385387
timeframes: list[str] | None = None,
386388
) -> RealtimeDataManager:

src/project_x_py/client/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,5 @@ class ProjectX(ProjectXBase):
9999
>>> # ... trading logic ...
100100
"""
101101

102-
pass
103-
104102

105103
__all__ = ["ProjectX", "RateLimiter"]

src/project_x_py/client/auth.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
import pytz
1111

12-
from ..exceptions import ProjectXAuthenticationError
13-
from ..models import Account
12+
from project_x_py.exceptions import ProjectXAuthenticationError
13+
from project_x_py.models import Account
1414

1515
if TYPE_CHECKING:
16-
from .base import ProjectXBase
16+
from .protocols import ProjectXClientProtocol
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -27,12 +27,12 @@ def __init__(self):
2727
self.token_expiry: datetime.datetime | None = None
2828
self._authenticated = False
2929

30-
async def _refresh_authentication(self: "ProjectXBase") -> None:
30+
async def _refresh_authentication(self: "ProjectXClientProtocol") -> None:
3131
"""Refresh authentication if token is expired or about to expire."""
3232
if self._should_refresh_token():
3333
await self.authenticate()
3434

35-
def _should_refresh_token(self: "ProjectXBase") -> bool:
35+
def _should_refresh_token(self: "ProjectXClientProtocol") -> bool:
3636
"""Check if token should be refreshed."""
3737
if not self.token_expiry:
3838
return True
@@ -41,7 +41,7 @@ def _should_refresh_token(self: "ProjectXBase") -> bool:
4141
buffer_time = timedelta(minutes=5)
4242
return datetime.datetime.now(pytz.UTC) >= (self.token_expiry - buffer_time)
4343

44-
async def authenticate(self: "ProjectXBase") -> None:
44+
async def authenticate(self: "ProjectXClientProtocol") -> None:
4545
"""
4646
Authenticate with ProjectX API and select account.
4747
@@ -133,12 +133,12 @@ async def authenticate(self: "ProjectXBase") -> None:
133133
f"Authenticated successfully. Using account: {selected_account.name}"
134134
)
135135

136-
async def _ensure_authenticated(self: "ProjectXBase") -> None:
136+
async def _ensure_authenticated(self: "ProjectXClientProtocol") -> None:
137137
"""Ensure client is authenticated before making API calls."""
138138
if not self._authenticated or self._should_refresh_token():
139139
await self.authenticate()
140140

141-
async def list_accounts(self: "ProjectXBase") -> list[Account]:
141+
async def list_accounts(self: "ProjectXClientProtocol") -> list[Account]:
142142
"""
143143
List all accounts available to the authenticated user.
144144

src/project_x_py/client/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import os
55
from collections.abc import AsyncGenerator
66
from contextlib import asynccontextmanager
7-
from typing import Any
87

9-
from ..config import ConfigManager
10-
from ..exceptions import ProjectXAuthenticationError
11-
from ..models import Account, ProjectXConfig
8+
import httpx
9+
10+
from project_x_py.config import ConfigManager
11+
from project_x_py.exceptions import ProjectXAuthenticationError
12+
from project_x_py.models import Account, ProjectXConfig
13+
1214
from .auth import AuthenticationMixin
1315
from .cache import CacheMixin
1416
from .http import HttpMixin
@@ -49,6 +51,9 @@ def __init__(
4951
self.api_key = api_key
5052
self.account_name = account_name
5153

54+
# Ensure _client is properly typed
55+
self._client: httpx.AsyncClient | None = None
56+
5257
# Use provided config or create default
5358
self.config = config or ProjectXConfig()
5459
self.base_url = self.config.api_url

src/project_x_py/client/cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
import polars as pl
99

10-
from ..models import Instrument
10+
from project_x_py.models import Instrument
1111

1212
if TYPE_CHECKING:
13-
from .base import ProjectXBase
13+
from .protocols import ProjectXClientProtocol
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -35,7 +35,7 @@ def __init__(self):
3535
# Performance monitoring
3636
self.cache_hit_count = 0
3737

38-
async def _cleanup_cache(self: "ProjectXBase") -> None:
38+
async def _cleanup_cache(self: "ProjectXClientProtocol") -> None:
3939
"""Clean up expired cache entries."""
4040
current_time = time.time()
4141

src/project_x_py/client/http.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import httpx
88

9-
from ..exceptions import (
9+
from project_x_py.exceptions import (
1010
ProjectXAuthenticationError,
1111
ProjectXConnectionError,
1212
ProjectXDataError,
@@ -16,7 +16,7 @@
1616
)
1717

1818
if TYPE_CHECKING:
19-
from .base import ProjectXBase
19+
from .protocols import ProjectXClientProtocol
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -29,7 +29,7 @@ def __init__(self):
2929
self._client: httpx.AsyncClient | None = None
3030
self.api_call_count = 0
3131

32-
async def _create_client(self: "ProjectXBase") -> httpx.AsyncClient:
32+
async def _create_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
3333
"""
3434
Create an optimized httpx async client with connection pooling and retries.
3535
@@ -73,14 +73,14 @@ async def _create_client(self: "ProjectXBase") -> httpx.AsyncClient:
7373

7474
return client
7575

76-
async def _ensure_client(self: "ProjectXBase") -> httpx.AsyncClient:
76+
async def _ensure_client(self: "ProjectXClientProtocol") -> httpx.AsyncClient:
7777
"""Ensure HTTP client is initialized."""
7878
if self._client is None:
7979
self._client = await self._create_client()
8080
return self._client
8181

8282
async def _make_request(
83-
self: "ProjectXBase",
83+
self: "ProjectXClientProtocol",
8484
method: str,
8585
endpoint: str,
8686
data: dict[str, Any] | None = None,
@@ -137,7 +137,12 @@ async def _make_request(
137137
)
138138
await asyncio.sleep(retry_after)
139139
return await self._make_request(
140-
method, endpoint, data, params, headers, retry_count + 1
140+
method=method,
141+
endpoint=endpoint,
142+
data=data,
143+
params=params,
144+
headers=headers,
145+
retry_count=retry_count + 1,
141146
)
142147
raise ProjectXRateLimitError("Rate limit exceeded after retries")
143148

@@ -153,7 +158,12 @@ async def _make_request(
153158
# Try to refresh authentication
154159
await self._refresh_authentication()
155160
return await self._make_request(
156-
method, endpoint, data, params, headers, retry_count + 1
161+
method=method,
162+
endpoint=endpoint,
163+
data=data,
164+
params=params,
165+
headers=headers,
166+
retry_count=retry_count + 1,
157167
)
158168
raise ProjectXAuthenticationError("Authentication failed")
159169

@@ -183,7 +193,12 @@ async def _make_request(
183193
)
184194
await asyncio.sleep(wait_time)
185195
return await self._make_request(
186-
method, endpoint, data, params, headers, retry_count + 1
196+
method=method,
197+
endpoint=endpoint,
198+
data=data,
199+
params=params,
200+
headers=headers,
201+
retry_count=retry_count + 1,
187202
)
188203
raise ProjectXServerError(
189204
f"Server error: {response.status_code} - {response.text}"
@@ -212,7 +227,7 @@ async def _make_request(
212227
raise ProjectXError(f"Unexpected error: {e}") from e
213228
raise
214229

215-
async def get_health_status(self: "ProjectXBase") -> dict[str, Any]:
230+
async def get_health_status(self: "ProjectXClientProtocol") -> dict[str, Any]:
216231
"""
217232
Get API health status and client statistics.
218233

src/project_x_py/client/market_data.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import polars as pl
1010
import pytz
1111

12-
from ..exceptions import ProjectXInstrumentError
13-
from ..models import Instrument
12+
from project_x_py.exceptions import ProjectXInstrumentError
13+
from project_x_py.models import Instrument
1414

1515
if TYPE_CHECKING:
16-
from .base import ProjectXBase
16+
from .protocols import ProjectXClientProtocol
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -22,7 +22,7 @@ class MarketDataMixin:
2222
"""Mixin class providing market data functionality."""
2323

2424
async def get_instrument(
25-
self: "ProjectXBase", symbol: str, live: bool = False
25+
self: "ProjectXClientProtocol", symbol: str, live: bool = False
2626
) -> Instrument:
2727
"""
2828
Get detailed instrument information with caching.
@@ -71,7 +71,9 @@ async def get_instrument(
7171
return instrument
7272

7373
def _select_best_contract(
74-
self: "ProjectXBase", instruments: list[dict[str, Any]], search_symbol: str
74+
self: "ProjectXClientProtocol",
75+
instruments: list[dict[str, Any]],
76+
search_symbol: str,
7577
) -> dict[str, Any]:
7678
"""
7779
Select the best matching contract from search results.
@@ -130,7 +132,7 @@ def _select_best_contract(
130132
return instruments[0]
131133

132134
async def search_instruments(
133-
self: "ProjectXBase", query: str, live: bool = False
135+
self: "ProjectXClientProtocol", query: str, live: bool = False
134136
) -> list[Instrument]:
135137
"""
136138
Search for instruments by symbol or name.
@@ -159,7 +161,7 @@ async def search_instruments(
159161
return [Instrument(**contract) for contract in contracts_data]
160162

161163
async def get_bars(
162-
self: "ProjectXBase",
164+
self: "ProjectXClientProtocol",
163165
symbol: str,
164166
days: int = 8,
165167
interval: int = 5,

0 commit comments

Comments
 (0)