Skip to content

Commit 219ac76

Browse files
committed
fix: Address critical PR review issues
- Fixed JWT token exposure by moving tokens from URLs to headers - Updated SignalR connection to use Authorization header - Removed access_token from query parameters - Replaced broad exception catching with specific types - Updated memory_management.py to handle specific exceptions - Updated core.py to catch ProjectX-specific exceptions - Added proper handling for asyncio.CancelledError - Added comprehensive input validation for trading calculations - Added validation for all numeric parameters - Proper type checking for required types - Clear error messages for invalid inputs These changes address the security vulnerabilities and improve error handling as requested in the PR review.
1 parent 71430b6 commit 219ac76

File tree

5 files changed

+82
-23
lines changed

5 files changed

+82
-23
lines changed

src/project_x_py/realtime/connection_management.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,15 @@ async def setup_connections(self: "ProjectXRealtimeClientProtocol") -> None:
5757
raise ImportError("signalrcore is required for real-time functionality")
5858

5959
async with self._connection_lock:
60-
# Build user hub connection
60+
# Build user hub connection with JWT in headers
6161
self.user_connection = (
6262
HubConnectionBuilder()
63-
.with_url(self.user_hub_url)
63+
.with_url(
64+
self.user_hub_url,
65+
options={
66+
"headers": {"Authorization": f"Bearer {self.jwt_token}"}
67+
},
68+
)
6469
.configure_logging(
6570
logging.INFO,
6671
socket_trace=False,
@@ -76,10 +81,15 @@ async def setup_connections(self: "ProjectXRealtimeClientProtocol") -> None:
7681
.build()
7782
)
7883

79-
# Build market hub connection
84+
# Build market hub connection with JWT in headers
8085
self.market_connection = (
8186
HubConnectionBuilder()
82-
.with_url(self.market_hub_url)
87+
.with_url(
88+
self.market_hub_url,
89+
options={
90+
"headers": {"Authorization": f"Bearer {self.jwt_token}"}
91+
},
92+
)
8393
.configure_logging(
8494
logging.INFO,
8595
socket_trace=False,
@@ -425,13 +435,9 @@ async def update_jwt_token(
425435
# Disconnect existing connections
426436
await self.disconnect()
427437

428-
# Update token
438+
# Update JWT token for header authentication
429439
self.jwt_token = new_jwt_token
430440

431-
# Update URLs with new token
432-
self.user_hub_url = f"{self.base_user_url}?access_token={new_jwt_token}"
433-
self.market_hub_url = f"{self.base_market_url}?access_token={new_jwt_token}"
434-
435441
# Reset setup flag to force new connection setup
436442
self.setup_complete = False
437443

src/project_x_py/realtime/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __init__(
133133
... )
134134
135135
Note:
136-
- JWT token is appended as access_token query parameter
136+
- JWT token is passed securely via Authorization header
137137
- Both hubs must connect successfully for full functionality
138138
- SignalR connections are established lazily on connect()
139139
"""
@@ -152,9 +152,9 @@ def __init__(
152152
final_user_url = user_hub_url or default_user_url
153153
final_market_url = market_hub_url or default_market_url
154154

155-
# Build complete URLs with authentication
156-
self.user_hub_url = f"{final_user_url}?access_token={jwt_token}"
157-
self.market_hub_url = f"{final_market_url}?access_token={jwt_token}"
155+
# Store URLs without tokens (tokens will be passed in headers)
156+
self.user_hub_url = final_user_url
157+
self.market_hub_url = final_market_url
158158

159159
# Set up base URLs for token refresh
160160
if config:

src/project_x_py/realtime_data_manager/core.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import polars as pl
1616
import pytz
1717

18+
from ..exceptions import ProjectXDataError, ProjectXError, ProjectXInstrumentError
1819
from .callbacks import CallbackMixin
1920
from .data_access import DataAccessMixin
2021
from .data_processing import DataProcessingMixin
@@ -350,8 +351,14 @@ async def initialize(self, initial_days: int = 1) -> bool:
350351
)
351352
return True
352353

353-
except Exception as e:
354-
self.logger.error(f"❌ Failed to initialize: {e}")
354+
except ProjectXInstrumentError as e:
355+
self.logger.error(f"❌ Failed to initialize - instrument error: {e}")
356+
return False
357+
except ProjectXDataError as e:
358+
self.logger.error(f"❌ Failed to initialize - data error: {e}")
359+
return False
360+
except ProjectXError as e:
361+
self.logger.error(f"❌ Failed to initialize - ProjectX error: {e}")
355362
return False
356363

357364
async def start_realtime_feed(self) -> bool:
@@ -447,8 +454,11 @@ async def on_new_bar(data):
447454
self.logger.info(f"✅ Real-time OHLCV feed started for {self.instrument}")
448455
return True
449456

450-
except Exception as e:
451-
self.logger.error(f"❌ Failed to start real-time feed: {e}")
457+
except RuntimeError as e:
458+
self.logger.error(f"❌ Failed to start real-time feed - runtime error: {e}")
459+
return False
460+
except asyncio.TimeoutError as e:
461+
self.logger.error(f"❌ Failed to start real-time feed - timeout: {e}")
452462
return False
453463

454464
async def stop_realtime_feed(self) -> None:
@@ -474,7 +484,7 @@ async def stop_realtime_feed(self) -> None:
474484

475485
self.logger.info(f"✅ Real-time feed stopped for {self.instrument}")
476486

477-
except Exception as e:
487+
except RuntimeError as e:
478488
self.logger.error(f"❌ Error stopping real-time feed: {e}")
479489

480490
async def cleanup(self) -> None:

src/project_x_py/realtime_data_manager/memory_management.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,19 @@ async def _periodic_cleanup(self: "RealtimeDataManagerProtocol") -> None:
7272
try:
7373
await asyncio.sleep(self.cleanup_interval)
7474
await self._cleanup_old_data()
75-
except Exception as e:
76-
self.logger.error(f"Error in periodic cleanup: {e}")
75+
except asyncio.CancelledError:
76+
# Task cancellation is expected during shutdown
77+
self.logger.debug("Periodic cleanup task cancelled")
78+
raise
79+
except MemoryError as e:
80+
self.logger.error(f"Memory error during cleanup: {e}")
81+
# Force immediate garbage collection
82+
import gc
83+
84+
gc.collect()
85+
except RuntimeError as e:
86+
self.logger.error(f"Runtime error in periodic cleanup: {e}")
87+
# Don't re-raise runtime errors to keep the cleanup task running
7788

7889
def get_memory_stats(self: "RealtimeDataManagerProtocol") -> dict:
7990
"""

src/project_x_py/utils/trading_calculations.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ def calculate_tick_value(
2222
>>> calculate_tick_value(0.5, 0.1, 1.0)
2323
5.0
2424
"""
25+
# Validate inputs
2526
if tick_size <= 0:
26-
return 0.0
27+
raise ValueError(f"tick_size must be positive, got {tick_size}")
28+
if tick_value < 0:
29+
raise ValueError(f"tick_value cannot be negative, got {tick_value}")
30+
if not isinstance(price_change, (int, float)):
31+
raise TypeError(f"price_change must be numeric, got {type(price_change)}")
2732

2833
num_ticks = abs(price_change) / tick_size
2934
return num_ticks * tick_value
@@ -49,8 +54,15 @@ def calculate_position_value(
4954
>>> calculate_position_value(5, 2050.0, 1.0, 0.1)
5055
102500.0
5156
"""
57+
# Validate inputs
5258
if tick_size <= 0:
53-
return 0.0
59+
raise ValueError(f"tick_size must be positive, got {tick_size}")
60+
if tick_value < 0:
61+
raise ValueError(f"tick_value cannot be negative, got {tick_value}")
62+
if price < 0:
63+
raise ValueError(f"price cannot be negative, got {price}")
64+
if not isinstance(size, int):
65+
raise TypeError(f"size must be an integer, got {type(size)}")
5466

5567
ticks_per_point = 1.0 / tick_size
5668
value_per_point = ticks_per_point * tick_value
@@ -68,12 +80,18 @@ def round_to_tick_size(price: float, tick_size: float) -> float:
6880
Returns:
6981
float: Price rounded to nearest tick
7082
83+
Raises:
84+
ValueError: If tick_size is not positive or price is negative
85+
7186
Example:
7287
>>> round_to_tick_size(2050.37, 0.1)
7388
2050.4
7489
"""
90+
# Validate inputs
7591
if tick_size <= 0:
76-
return price
92+
raise ValueError(f"tick_size must be positive, got {tick_size}")
93+
if price < 0:
94+
raise ValueError(f"price cannot be negative, got {price}")
7795

7896
return round(price / tick_size) * tick_size
7997

@@ -142,6 +160,20 @@ def calculate_position_sizing(
142160
>>> sizing = calculate_position_sizing(50000, 0.02, 2050, 2040, 1.0)
143161
>>> print(f"Position size: {sizing['position_size']} contracts")
144162
"""
163+
# Validate inputs
164+
if account_balance <= 0:
165+
raise ValueError(f"account_balance must be positive, got {account_balance}")
166+
if not 0 < risk_per_trade <= 1:
167+
raise ValueError(
168+
f"risk_per_trade must be between 0 and 1, got {risk_per_trade}"
169+
)
170+
if entry_price <= 0:
171+
raise ValueError(f"entry_price must be positive, got {entry_price}")
172+
if stop_loss_price <= 0:
173+
raise ValueError(f"stop_loss_price must be positive, got {stop_loss_price}")
174+
if tick_value <= 0:
175+
raise ValueError(f"tick_value must be positive, got {tick_value}")
176+
145177
try:
146178
# Calculate risk per share/contract
147179
price_risk = abs(entry_price - stop_loss_price)

0 commit comments

Comments
 (0)