diff --git a/CHANGELOG.md b/CHANGELOG.md index f898029..b536cc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,52 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Migration guides will be provided for all breaking changes - Semantic versioning (MAJOR.MINOR.PATCH) is strictly followed +## [3.1.13] - 2025-08-15 + +### Fixed +- **🎯 Event System Data Structure Mismatches**: Fixed critical order fill detection issues + - Bracket orders now properly detect fills without timing out + - Event handlers now correctly handle both `order_id` and nested `order` object structures + - Added backward compatibility for different event payload formats + - ManagedTrade now listens to correct events (ORDER_FILLED vs ORDER_MODIFIED) + +- **📝 Type Annotations for SignalR Connections**: Improved IDE support and type safety + - Created HubConnection type alias for BaseHubConnection + - Retyped market_connection and user_connection from Any to proper connection types + - IDEs now recognize connection methods (send, on, start, stop) + - Updated ProjectXRealtimeClientProtocol to match implementation + +- **🔧 Real-time Connection Improvements**: Enhanced WebSocket stability + - Added circuit breaker pattern to BatchedWebSocketHandler + - Improved subscription handling with proper event waiting + - Fixed asyncio deprecation warnings (get_event_loop → get_running_loop) + - Better error handling and recovery mechanisms + +### Improved +- **📊 Data Storage Robustness**: Major improvements to mmap_storage module + - Fixed critical bug causing data overwrite on initialization + - Implemented binary search for read_window (significant performance boost) + - Added thread-safe operations with RLock + - Fixed file corruption bug in _resize_file + - Replaced print statements with proper logging + +- **🧪 Test Coverage**: Dramatically improved client module testing + - Client module coverage increased from 30% to 93% + - Added 70+ comprehensive test cases across all client components + - Fixed bug in _select_best_contract method + - Full test coverage for base.py (100%) and trading.py (98%) + +- **🏗️ Order and Position Management**: Enhanced tracking and stability + - Improved order tracking with better event handling + - More robust position manager logic + - Better error recovery in order chains + - Enhanced TradingSuite configuration options + +### Documentation +- Updated CHANGELOG.md with comprehensive v3.1.13 changes +- Updated CLAUDE.md Recent Changes section +- Added detailed commit messages for all fixes + ## [3.1.12] - 2025-08-15 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index 0afcf98..95b078b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,7 +2,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. -## Project Status: v3.1.9 - Stable Production Release +## Project Status: v3.1.13 - Stable Production Release **IMPORTANT**: This project uses a fully asynchronous architecture. All APIs are async-only, optimized for high-performance futures trading.
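To ground the async-only note above, here is a minimal usage sketch assembled from APIs that appear elsewhere in this diff (`TradingSuite.create`, `suite.data.get_data`, `suite.managed_trade`); treat it as an illustration of the architecture, not canonical documentation:

```python
import asyncio

from project_x_py import TradingSuite


async def main() -> None:
    # Every public API is a coroutine; the suite doubles as an async context manager.
    async with await TradingSuite.create(
        "MNQ",
        timeframes=["1min", "5min"],
        features=["risk_manager"],  # optional features; see the Features enum below
    ) as suite:
        bars = await suite.data.get_data("5min")  # Polars DataFrame (or None)
        if bars is not None and not bars.is_empty():
            print(bars.tail(3))

        # Risk-managed entry; the entry price is fetched automatically (v3.1.11+).
        async with suite.managed_trade(max_risk_percent=0.01) as trade:
            ...


asyncio.run(main())
```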
@@ -162,6 +162,18 @@ uv run python -m build # Alternative build command - Async event handlers with priority support - Built-in event types for all trading events +### Available TradingSuite Features + +The `Features` enum defines optional components that can be enabled: + +- `ORDERBOOK = "orderbook"` - Level 2 market depth and analysis +- `RISK_MANAGER = "risk_manager"` - Position sizing and risk management +- `TRADE_JOURNAL = "trade_journal"` - Trade logging (future) +- `PERFORMANCE_ANALYTICS = "performance_analytics"` - Advanced metrics (future) +- `AUTO_RECONNECT = "auto_reconnect"` - Automatic reconnection (future) + +**Note**: OrderManager and PositionManager are always included by default. + ### Architecture Patterns **Async Factory Functions**: Use async `create_*` functions for component initialization: @@ -288,7 +300,20 @@ async with ProjectX.from_env() as client: ## Recent Changes -### v3.1.12 - Latest Release +### v3.1.13 - Latest Release +- **Fixed**: Event system data structure mismatches causing order fill detection failures + - Bracket orders now properly detect fills without 60-second timeouts + - Event handlers handle both `order_id` and nested `order` object structures + - ManagedTrade correctly listens to ORDER_FILLED instead of ORDER_MODIFIED +- **Fixed**: Type annotations for SignalR hub connections + - Created HubConnection type alias for proper IDE support + - market_connection and user_connection now have proper types instead of Any +- **Improved**: Real-time connection stability with circuit breaker pattern +- **Improved**: Data storage robustness with thread-safety and performance optimizations +- **Enhanced**: Test coverage increased from 30% to 93% for client module +- **Fixed**: Multiple asyncio deprecation warnings + +### v3.1.12 - **Enhanced**: Significantly improved `01_events_with_on.py` real-time data example - Added CSV export functionality with interactive prompts - Plotly-based candlestick chart generation @@ -375,7 +400,7 @@ async def main(): suite = await TradingSuite.create( "MNQ", timeframes=["1min", "5min"], - features=["orderbook", "risk_manager"], + features=["orderbook", "risk_manager"], # Optional features initial_days=5 ) diff --git a/README.md b/README.md index 2c7fd7d..bd58636 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ A **high-performance async Python SDK** for the [ProjectX Trading Platform](http This Python SDK acts as a bridge between your trading strategies and the ProjectX platform, handling all the complex API interactions, data processing, and real-time connectivity. -## 🚀 v3.1.11 - Stable Production Release +## 🚀 v3.1.13 - Stable Production Release -**Latest Version**: v3.1.11 - Fixed ManagedTrade market price fetching for risk-managed trades. See [CHANGELOG.md](CHANGELOG.md) for full release history. +**Latest Version**: v3.1.13 - Fixed critical event system issues affecting bracket order fill detection and improved real-time connection stability. See [CHANGELOG.md](CHANGELOG.md) for full release history.
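The headline fix above (handlers accepting both a flat `order_id` field and a nested `order` object) boils down to defensive payload extraction. A minimal sketch of that pattern — the helper name and exact payload shapes are illustrative assumptions, not the SDK's internals:

```python
from typing import Any


def extract_order_id(payload: dict[str, Any]) -> int | None:
    """Return the order id from either event payload shape, or None."""
    # Newer payloads carry a flat order_id field.
    if "order_id" in payload:
        return payload["order_id"]
    # Older payloads nest the order; it may be a dict or a model object.
    order = payload.get("order")
    if isinstance(order, dict):
        return order.get("id")
    return getattr(order, "id", None)
```

Supporting both shapes is what lets bracket orders resolve their fills instead of hitting the 60-second timeout described in the changelog.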
### 📦 Production Stability Guarantee @@ -251,6 +251,38 @@ Or use a config file (`~/.config/projectx/config.json`): } ``` +### Available Features + +TradingSuite supports optional features that can be enabled during initialization: + +| Feature | String Value | Description | +|---------|-------------|-------------| +| **OrderBook** | `"orderbook"` | Level 2 market depth, bid/ask analysis, iceberg detection | +| **Risk Manager** | `"risk_manager"` | Position sizing, risk validation, managed trades | +| **Trade Journal** | `"trade_journal"` | Trade logging and performance tracking (future) | +| **Performance Analytics** | `"performance_analytics"` | Advanced metrics and analysis (future) | +| **Auto Reconnect** | `"auto_reconnect"` | Automatic WebSocket reconnection (future) | + +**Note:** PositionManager and OrderManager are always included and don't require feature flags. + +```python +# Enable specific features +suite = await TradingSuite.create( + "MNQ", + features=["orderbook", "risk_manager"] +) + +# Access feature-specific components +if suite.orderbook: # Only available when orderbook feature is enabled + spread = await suite.orderbook.get_bid_ask_spread() + +if suite.risk_manager: # Only available when risk_manager feature is enabled + sizing = await suite.risk_manager.calculate_position_size( + entry_price=100.0, + stop_loss=99.0 + ) +``` + ### Component Overview #### ProjectX Client @@ -296,11 +328,19 @@ icebergs = await suite.orderbook.detect_iceberg_orders() ``` #### RiskManager -Risk management and managed trades (when enabled): +Risk management and managed trades (requires feature flag): ```python # Enable risk manager in features suite = await TradingSuite.create("MNQ", features=["risk_manager"]) +# Risk manager integrates with PositionManager automatically +# Use for position sizing and risk validation +sizing = await suite.risk_manager.calculate_position_size( + entry_price=100.0, + stop_loss=99.0, + risk_percent=0.02 # Risk 2% of account +) + # Use managed trades for automatic risk management async with suite.managed_trade(max_risk_percent=0.01) as trade: # Market price fetched automatically (v3.1.11+) @@ -310,6 +350,8 @@ async with suite.managed_trade(max_risk_percent=0.01) as trade: ) ``` +**Note:** RiskManager requires the `"risk_manager"` feature flag and automatically integrates with PositionManager for comprehensive risk tracking. + ### Technical Indicators All 58+ indicators work with async data pipelines: diff --git a/RELEASE_NOTES_v3.1.11.md b/RELEASE_NOTES_v3.1.11.md deleted file mode 100644 index 9884a38..0000000 --- a/RELEASE_NOTES_v3.1.11.md +++ /dev/null @@ -1,69 +0,0 @@ -# Release Notes - v3.1.11 - -## 🎯 Risk Manager Market Price Fetching Fix - -### Overview -This release fixes a critical issue in the Risk Manager's ManagedTrade class where the `_get_market_price()` method was not implemented, preventing users from entering risk-managed trades without explicitly providing an entry price.
- -### What's Fixed - -#### ManagedTrade Market Price Implementation -- **Problem**: When using `ManagedTrade.enter_long()` or `enter_short()` without an explicit `entry_price`, the system would fail with `NotImplementedError` -- **Solution**: Fully implemented `_get_market_price()` method that fetches current market prices from the data manager -- **Impact**: Risk-managed trades can now be entered using current market prices automatically - -### Technical Details - -#### Implementation Features -- **Smart Timeframe Fallback**: Tries multiple timeframes in order (1sec → 15sec → 1min → 5min) to get the most recent price -- **Direct Price Access**: Falls back to `get_current_price()` if bar data isn't available -- **Data Manager Integration**: ManagedTrade now receives data manager from TradingSuite automatically -- **Clear Error Messages**: Provides helpful error messages when market price cannot be fetched - -#### Code Changes -```python -# Before (would fail) -async with suite.managed_trade(max_risk_percent=0.01) as trade: - result = await trade.enter_long( - stop_loss=current_price - 50, # Would throw NotImplementedError - take_profit=current_price + 100 - ) - -# After (works perfectly) -async with suite.managed_trade(max_risk_percent=0.01) as trade: - result = await trade.enter_long( - stop_loss=current_price - 50, # Automatically fetches market price - take_profit=current_price + 100 - ) -``` - -### Migration Guide -No breaking changes. Existing code will continue to work. The enhancement is backward compatible: -- If you provide `entry_price` explicitly, it works as before -- If you omit `entry_price`, the system now fetches it automatically - -### Testing -The implementation has been tested with live market data: -- ✅ Market price fetching works correctly -- ✅ Fallback through multiple timeframes functions properly -- ✅ Integration with TradingSuite is seamless -- ✅ Risk orders (stop loss, take profit) are properly attached - -### Dependencies -No new dependencies required. Uses existing data manager infrastructure. - -### Known Issues -None at this time. - -### Future Improvements -- Consider adding configurable timeframe priority for price fetching -- Add option to use bid/ask prices instead of last trade price -- Implement price staleness checks with configurable thresholds - -### Support -For issues or questions about this release, please open an issue on GitHub or contact support.
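Although these release notes are being deleted in this diff, the fallback cascade they describe is the heart of the v3.1.11 fix. A sketch of how such a `_get_market_price()` could look, assuming `get_data(timeframe)` returns a Polars DataFrame with a `close` column and that `get_current_price()` exists on the data manager (both are named in the notes above); the body here is a reconstruction, not the shipped code:

```python
import polars as pl


async def _get_market_price(data_manager) -> float:
    # Walk timeframes from finest to coarsest until one has a usable bar.
    for timeframe in ("1sec", "15sec", "1min", "5min"):
        bars: pl.DataFrame | None = await data_manager.get_data(timeframe)
        if bars is not None and not bars.is_empty():
            return float(bars.select("close").tail(1).item())
    # Direct price access as a last resort.
    price = await data_manager.get_current_price()
    if price is None:
        raise ValueError("Unable to fetch a market price from any timeframe")
    return float(price)
```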
- --- -*Released: 2025-08-13* -*Version: 3.1.11* -*Type: Bug Fix / Feature Enhancement* \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e46611b..c2337d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,8 +23,8 @@ project = "project-x-py" copyright = "2025, Jeff West" author = "Jeff West" -release = "3.1.12" -version = "3.1.12" +release = "3.1.13" +version = "3.1.13" # -- General configuration --------------------------------------------------- diff --git a/examples/02_order_management.py b/examples/02_order_management.py index dd0b6bf..9e83165 100644 --- a/examples/02_order_management.py +++ b/examples/02_order_management.py @@ -270,7 +270,22 @@ async def main() -> bool: print("📝 EXAMPLE 3: BRACKET ORDER") print("=" * 50) - entry_price = current_price - Decimal("5.0") # Entry $5 below market + try: + market_data = await suite.client.get_bars( + "MNQ", days=days, interval=interval + ) + if market_data is not None and not market_data.is_empty(): + current_price = Decimal( + str(market_data.select("close").tail(1).item()) + ) + latest_time = market_data.select("timestamp").tail(1).item() + print(f"✅ Retrieved MNQ price: ${current_price:.2f}") + print(f" Data from: {latest_time} ({days}d {interval}min bars)") + except Exception: + print("❌ Failed to get market data") + return False + + entry_price = current_price - Decimal("1.00") # Entry $1 below market stop_loss = entry_price - Decimal("10.0") # $10 risk take_profit = entry_price + Decimal("20.0") # $20 profit target (2:1 R/R) diff --git a/examples/03_position_management.py b/examples/03_position_management.py index 84f8eb2..d3a8ce3 100644 --- a/examples/03_position_management.py +++ b/examples/03_position_management.py @@ -247,6 +247,7 @@ async def main() -> bool: suite = await TradingSuite.create( "MNQ", timeframes=["1min", "5min"], + features=["risk_manager"], ) print( diff --git a/examples/14_phase4_comprehensive_test.py b/examples/14_phase4_comprehensive_test.py deleted file mode 100644 index 9ba7b95..0000000 --- a/examples/14_phase4_comprehensive_test.py +++ /dev/null @@ -1,439 +0,0 @@ -#!/usr/bin/env python3 -""" -Phase 4 Comprehensive Test - Data and Orders Improvements - -This example demonstrates all the improvements from Phase 4: -- Simplified data access methods -- Enhanced model properties -- Cleaner strategy implementation - -Author: SDK v3.0.2 Testing -""" - -import asyncio -from datetime import datetime - -from project_x_py import TradingSuite -from project_x_py.indicators import ATR, RSI, SMA - - -class CleanTradingStrategy: - """A trading strategy using all Phase 4 improvements.""" - - def __init__(self, suite: TradingSuite): - self.suite = suite - self.data = suite.data - self.orders = suite.orders - self.positions = suite.positions - - # Strategy parameters - self.max_position_size = 5 - self.profit_target_ticks = 20 - self.stop_loss_ticks = 10 - - async def analyze_market( - self, - ) -> dict[str, float | int | str | dict[str, float]] | None: - """Analyze market using simplified data access.""" - # Use new get_data_or_none for cleaner code - data = await self.data.get_data_or_none("5min", min_bars=50) - if data is None: - return None - - # Calculate indicators - data = data.pipe(SMA, period=20).pipe(RSI, period=14).pipe(ATR, period=14) - - # Get current market state using new methods - current_price = await self.data.get_latest_price() - ohlc = await self.data.get_ohlc("5min") - price_range = await self.data.get_price_range(bars=20) - volume_stats = await
self.data.get_volume_stats(bars=20) - - if ( - current_price is None - or ohlc is None - or price_range is None - or volume_stats is None - ): - return None - - # Analyze trend - sma20 = float(data["sma_20"][-1]) - rsi = float(data["rsi_14"][-1]) - atr = float(data["atr_14"][-1]) - - # Price position within range - price_position = (float(current_price) - float(price_range["low"])) / float( - price_range["range"] - ) - - return { - "price": float(current_price), - "trend": "bullish" if float(current_price) > sma20 else "bearish", - "trend_strength": abs(float(current_price) - sma20) / sma20, - "rsi": rsi, - "atr": atr, - "price_position": price_position, - "volume_relative": float(volume_stats["relative"]), - "range": float(price_range["range"]), - "ohlc": { - "open": float(ohlc["open"]), - "high": float(ohlc["high"]), - "low": float(ohlc["low"]), - "close": float(ohlc["close"]), - }, - } - - async def check_positions( - self, - ) -> dict[str, float | int | list[dict[str, float | int | str]]]: - """Check positions using enhanced model properties.""" - positions = await self.positions.get_all_positions() - - position_summary = { - "total_positions": len(positions), - "long_positions": 0, - "short_positions": 0, - "total_exposure": 0.0, - "positions": [], - } - - current_price = await self.data.get_latest_price() - if current_price is None: - return position_summary - - for pos in positions: - # Use new Position properties - if pos.is_long: - position_summary["long_positions"] += 1 - elif pos.is_short: - position_summary["short_positions"] += 1 - - position_summary["total_exposure"] += pos.total_cost - - # Calculate P&L using the unrealized_pnl method - pnl = pos.unrealized_pnl(float(current_price), tick_value=5.0) - - position_summary["positions"].append( - { - "id": pos.id, - "symbol": pos.symbol, # New property - "direction": pos.direction, # New property - "size": pos.size, - "signed_size": pos.signed_size, # New property - "entry": pos.averagePrice, - "pnl": pnl, - "pnl_ticks": (pnl / 5.0) / pos.size if pos.size > 0 else 0, - } - ) - - return position_summary - - async def check_orders( - self, - ) -> dict[str, float | int | list[dict[str, float | int | str]]]: - """Check orders using enhanced model properties.""" - orders = await self.orders.search_open_orders() - - order_summary = { - "total_orders": len(orders), - "working_orders": 0, - "buy_orders": 0, - "sell_orders": 0, - "orders": [], - } - - for order in orders: - # Use new Order properties - if order.is_working: - order_summary["working_orders"] += 1 - - if order.is_buy: - order_summary["buy_orders"] += 1 - elif order.is_sell: - order_summary["sell_orders"] += 1 - - order_summary["orders"].append( - { - "id": order.id, - "symbol": order.symbol, # New property - "side": order.side_str, # New property - "type": order.type_str, # New property - "status": order.status_str, # New property - "size": order.size, - "remaining": order.remaining_size, # New property - "filled_pct": order.filled_percent, # New property - "price": float(order.limitPrice) - if order.limitPrice - else float(order.stopPrice) - if order.stopPrice - else 0.0, - } - ) - - return order_summary - - async def execute_strategy(self) -> None: - """Execute trading strategy using all Phase 4 improvements.""" - print("\n=== Strategy Execution ===") - - # 1. Check if data is ready - if not await self.data.is_data_ready(min_bars=50): - print("โณ Insufficient data for strategy") - return - - # 2. 
Analyze market - analysis = await self.analyze_market() - if not analysis: - print("โŒ Market analysis failed") - return - - print("\n๐Ÿ“Š Market Analysis:") - if not isinstance(analysis["trend"], str): - print("โŒ Trend is not a string") - return - - if not isinstance(analysis["trend_strength"], float): - print("โŒ Trend strength is not a float") - return - - print(f" Price: ${analysis['price']:,.2f}") - print( - f" Trend: {analysis['trend'].upper()} (strength: {analysis['trend_strength']:.1%})" - ) - print(f" RSI: {analysis['rsi']:.1f}") - print(f" Volume: {analysis['volume_relative']:.1%} of average") - print(f" Price Position: {analysis['price_position']:.1%} of range") - - # 3. Check current positions - position_summary = await self.check_positions() - print("\n๐Ÿ“ˆ Position Summary:") - print(f" Total: {position_summary['total_positions']}") - print(f" Long: {position_summary['long_positions']}") - print(f" Short: {position_summary['short_positions']}") - print(f" Exposure: ${position_summary['total_exposure']:,.2f}") - - if not isinstance(position_summary["positions"], list): - print("โŒ Position summary is not a list") - return - - for pos in position_summary["positions"]: - print( - f"\n {pos['direction']} {pos['size']} {pos['symbol']} @ ${pos['entry']:,.2f}" - ) - print(f" P&L: ${pos['pnl']:+,.2f} ({pos['pnl_ticks']:+.1f} ticks)") - - if not isinstance(pos["pnl_ticks"], float): - print("โŒ P&L ticks is not a float") - return - - # Exit logic using clean properties - if pos["pnl_ticks"] >= self.profit_target_ticks: - print(" โœ… PROFIT TARGET REACHED!") - elif pos["pnl_ticks"] <= -self.stop_loss_ticks: - print(" ๐Ÿ›‘ STOP LOSS TRIGGERED!") - - # 4. Check current orders - order_summary = await self.check_orders() - if not isinstance(order_summary["orders"], list): - print("โŒ Order summary is not a list") - return - - if not isinstance(order_summary["total_orders"], int): - print("โŒ Order summary is not an integer") - return - - if not isinstance(order_summary["working_orders"], int): - print("โŒ Order summary is not an integer") - return - - if not isinstance(order_summary["buy_orders"], int): - print("โŒ Order summary is not an integer") - return - - if order_summary["total_orders"] > 0: - print("\n๐Ÿ“‹ Order Summary:") - print(f" Working: {order_summary['working_orders']}") - print(f" Buy Orders: {order_summary['buy_orders']}") - print(f" Sell Orders: {order_summary['sell_orders']}") - - for order in order_summary["orders"]: - print( - f"\n {order['side']} {order['size']} {order['symbol']} - {order['type']}" - ) - print( - f" Status: {order['status']} ({order['filled_pct']:.0f}% filled)" - ) - if order["price"]: - print(f" Price: ${order['price']:,.2f}") - - # 5. 
Generate trading signals - signal = self._generate_signal(analysis, position_summary) - if signal: - print( - f"\n๐ŸŽฏ SIGNAL: {signal['action']} (confidence: {signal['confidence']:.1%})" - ) - print(f" Reason: {signal['reason']}") - - def _generate_signal( - self, analysis: dict, positions: dict - ) -> dict[str, float | str] | None: - """Generate trading signal based on analysis.""" - # No signal if we have max positions - if positions["total_positions"] >= self.max_position_size: - return None - - # Bullish signal - if ( - analysis["trend"] == "bullish" - and analysis["rsi"] < 70 - and analysis["volume_relative"] > 1.2 - and analysis["price_position"] < 0.7 - ): - return { - "action": "BUY", - "confidence": min( - analysis["trend_strength"], - (70 - analysis["rsi"]) / 50, - analysis["volume_relative"] - 1.0, - ), - "reason": "Bullish trend with momentum, not overbought", - } - - # Bearish signal - elif ( - analysis["trend"] == "bearish" - and analysis["rsi"] > 30 - and analysis["volume_relative"] > 1.2 - and analysis["price_position"] > 0.3 - ): - return { - "action": "SELL", - "confidence": min( - analysis["trend_strength"], - (analysis["rsi"] - 30) / 50, - analysis["volume_relative"] - 1.0, - ), - "reason": "Bearish trend with momentum, not oversold", - } - - return None - - -async def demonstrate_phase4_improvements() -> None: - """Demonstrate all Phase 4 improvements in action.""" - - async with await TradingSuite.create( - "MNQ", timeframes=["1min", "5min", "15min"], initial_days=2 - ) as suite: - print("ProjectX SDK v3.0.2 - Phase 4 Comprehensive Test") - print("=" * 60) - - strategy = CleanTradingStrategy(suite) - - # 1. Test simplified data access - print("\n1๏ธโƒฃ Testing Simplified Data Access") - print("-" * 40) - - # Old way vs new way comparison - print("OLD: data = await manager.get_data('5min')") - print(" if data is None or len(data) < 50:") - print(" return") - print("\nNEW: data = await manager.get_data_or_none('5min', min_bars=50)") - print(" if data is None:") - print(" return") - - # Test new methods - latest_bars = await suite.data.get_latest_bars(5) - if latest_bars is not None: - print(f"\nโœ… get_latest_bars(): Got {len(latest_bars)} bars") - - price = await suite.data.get_latest_price() - print(f"โœ… get_latest_price(): ${price:,.2f}") - - ohlc = await suite.data.get_ohlc() - if ohlc: - print( - f"โœ… get_ohlc(): O:{ohlc['open']:,.2f} H:{ohlc['high']:,.2f} " - f"L:{ohlc['low']:,.2f} C:{ohlc['close']:,.2f}" - ) - - # 2. 
Test enhanced models - print("\n\n2๏ธโƒฃ Testing Enhanced Model Properties") - print("-" * 40) - - # Create demo position - from project_x_py.models import Order, Position - - demo_pos = Position( - id=1, - accountId=1, - contractId="CON.F.US.MNQ.H25", - creationTimestamp=datetime.now().isoformat(), - type=1, - size=2, - averagePrice=16500.0, - ) - - print("Position Properties:") - print(f" direction: {demo_pos.direction}") - print(f" symbol: {demo_pos.symbol}") - print(f" is_long: {demo_pos.is_long}") - print(f" signed_size: {demo_pos.signed_size}") - print(f" total_cost: ${demo_pos.total_cost:,.2f}") - - # Create demo order - demo_order = Order( - id=1, - accountId=1, - contractId="CON.F.US.MNQ.H25", - creationTimestamp=datetime.now().isoformat(), - updateTimestamp=None, - status=1, - type=1, - side=0, - size=5, - fillVolume=2, - limitPrice=16450.0, - ) - - print("\nOrder Properties:") - print(f" side_str: {demo_order.side_str}") - print(f" type_str: {demo_order.type_str}") - print(f" status_str: {demo_order.status_str}") - print(f" is_working: {demo_order.is_working}") - print(f" filled_percent: {demo_order.filled_percent:.0f}%") - print(f" remaining_size: {demo_order.remaining_size}") - - # 3. Execute full strategy - print("\n\n3๏ธโƒฃ Testing Complete Strategy Implementation") - print("-" * 40) - - await strategy.execute_strategy() - - # 4. Performance comparison - print("\n\n4๏ธโƒฃ Code Complexity Comparison") - print("-" * 40) - print("Lines of code reduced:") - print(" Data access: ~10 lines โ†’ 2 lines (80% reduction)") - print(" Position checks: ~15 lines โ†’ 5 lines (67% reduction)") - print(" Order filtering: ~8 lines โ†’ 3 lines (63% reduction)") - print("\nโœ… Overall: Cleaner, more readable, less error-prone code!") - - -async def main() -> None: - """Run Phase 4 comprehensive test.""" - try: - await demonstrate_phase4_improvements() - - except KeyboardInterrupt: - print("\n\nTest interrupted by user") - except Exception as e: - print(f"\nโŒ Error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/README_instrument_demo.md b/examples/README_instrument_demo.md deleted file mode 100644 index 5115be3..0000000 --- a/examples/README_instrument_demo.md +++ /dev/null @@ -1,59 +0,0 @@ -# Instrument Search Demo - -## Overview -The `09_get_check_available_instruments.py` script is an interactive demo that showcases the ProjectX instrument search functionality. - -## Features -- **Interactive search**: Enter any futures symbol to search for contracts -- **Smart contract selection**: See how `get_instrument()` selects the best matching contract -- **All contracts view**: Compare all available contracts with `search_instruments()` -- **Visual formatting**: Clear display of contract details with active contract indicators - -## Usage -```bash -# Make sure you have environment variables set -export PROJECT_X_API_KEY="your-api-key" -export PROJECT_X_USERNAME="your-username" - -# Run the demo -python examples/09_get_check_available_instruments.py -``` - -## Example Session -``` -Enter a symbol to search (or 'quit' to exit): NQ - -============================================================ -Searching for: 'NQ' -============================================================ - -1. All contracts matching 'NQ': --------------------------------------------------- - Found 2 contract(s): - - โ˜… [1] NQU5 - E-mini NASDAQ-100: September 2025 - [2] MNQU5 - Micro E-mini Nasdaq-100: September 2025 - -2. 
Best match using get_instrument('NQ'): --------------------------------------------------- - Selected: NQU5 - โ”Œโ”€ Contract Details โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - โ”‚ ID: CON.F.US.ENQ.U25 - โ”‚ Name: NQU5 - โ”‚ Symbol ID: F.US.ENQ - โ”‚ Description: E-mini NASDAQ-100: September 2025 - โ”‚ Active: โœ“ Yes - โ”‚ Tick Size: 0.25 - โ”‚ Tick Value: $5.0 - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -``` - -## Contract Naming -- Contract names include futures month codes + year -- Examples: `NQU5` = NQ + U (September) + 5 (2025) -- Month codes: F(Jan), G(Feb), H(Mar), J(Apr), K(May), M(Jun), N(Jul), Q(Aug), U(Sep), V(Oct), X(Nov), Z(Dec) - -## Commands -- **Symbol name**: Search for that symbol (e.g., "ES", "NQ", "CL") -- **help**: Show common symbols table -- **quit/exit/q**: Exit the demo \ No newline at end of file diff --git a/examples/realtime_data_manager/00_events_with_wait_for.py b/examples/realtime_data_manager/00_events_with_wait_for.py index beac207..bb0868c 100644 --- a/examples/realtime_data_manager/00_events_with_wait_for.py +++ b/examples/realtime_data_manager/00_events_with_wait_for.py @@ -12,11 +12,11 @@ async def on_new_bar(suite: TradingSuite): print("=" * 80) if last_bars is not None and not last_bars.is_empty(): - print("Last 5 bars (oldest to newest):") + print("Last 6 bars (oldest to newest):") print("-" * 80) # Get the last 5 bars and iterate through them - for row in last_bars.tail(5).iter_rows(named=True): + for row in last_bars.tail(6).iter_rows(named=True): timestamp = row["timestamp"] open_price = row["open"] high = row["high"] diff --git a/examples/realtime_data_manager/01_events_with_on.py b/examples/realtime_data_manager/01_events_with_on.py index 594573b..89d6f00 100644 --- a/examples/realtime_data_manager/01_events_with_on.py +++ b/examples/realtime_data_manager/01_events_with_on.py @@ -239,7 +239,7 @@ async def on_new_bar(event): return try: - last_bars = await suite.data.get_data(timeframe=TIMEFRAME, bars=5) + last_bars = await suite.data.get_data(timeframe=TIMEFRAME, bars=6) except Exception as e: print(f"Error getting data: {e}") return diff --git a/examples/verify_orderblock_fvg_indicators.py b/examples/verify_orderblock_fvg_indicators.py deleted file mode 100755 index 28a6d68..0000000 --- a/examples/verify_orderblock_fvg_indicators.py +++ /dev/null @@ -1,262 +0,0 @@ -#!/usr/bin/env python -""" -Verification script for Order Block and Fair Value Gap indicators. - -This script loads historical market data and applies both the Order Block -and FVG indicators to verify they are detecting patterns correctly. -It displays the full dataframes to show all detected patterns. 
- -Author: ProjectX SDK -Date: 2025-01-12 -""" - -import asyncio -import os -import sys -from datetime import datetime - -import polars as pl - -# Add parent directory to path for imports -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from project_x_py import ProjectX -from project_x_py.indicators import FVG, ORDERBLOCK - - -def display_dataframe_info(df: pl.DataFrame, title: str) -> None: - """Display comprehensive information about a dataframe.""" - print(f"\n{'=' * 80}") - print(f"{title}") - print(f"{'=' * 80}") - print(f"Shape: {df.shape} (rows: {df.height}, columns: {df.width})") - print(f"Columns: {df.columns}") - print(f"Data types: {df.dtypes}") - print() - - -def display_order_blocks(df: pl.DataFrame) -> None: - """Display detected Order Blocks with details.""" - # Filter for rows where order blocks are detected - ob_columns = [col for col in df.columns if col.startswith("ob_")] - - if not ob_columns: - print("No Order Block columns found in dataframe!") - return - - # Check for bullish order blocks - if "ob_bullish" in df.columns: - bullish_obs = df.filter(pl.col("ob_bullish") == True) - if bullish_obs.height > 0: - print(f"\n๐Ÿ“ˆ BULLISH ORDER BLOCKS DETECTED: {bullish_obs.height}") - print("-" * 60) - # Limit display to first 10 for readability - for row in bullish_obs.head(10).iter_rows(named=True): - print(f" Time: {row['timestamp']}") - if "ob_top" in row and "ob_bottom" in row: - print( - f" Order Block Zone: ${row['ob_bottom']:.2f} - ${row['ob_top']:.2f}" - ) - print(f" Bar Range: ${row['low']:.2f} - ${row['high']:.2f}") - print(f" Volume: {row['volume']:,}") - if "ob_strength" in row: - print(f" Strength: {row['ob_strength']:.4f}") - print() - if bullish_obs.height > 10: - print(f" ... and {bullish_obs.height - 10} more") - - # Check for bearish order blocks - if "ob_bearish" in df.columns: - bearish_obs = df.filter(pl.col("ob_bearish") == True) - if bearish_obs.height > 0: - print(f"\n๐Ÿ“‰ BEARISH ORDER BLOCKS DETECTED: {bearish_obs.height}") - print("-" * 60) - # Limit display to first 10 for readability - for row in bearish_obs.head(10).iter_rows(named=True): - print(f" Time: {row['timestamp']}") - if "ob_top" in row and "ob_bottom" in row: - print( - f" Order Block Zone: ${row['ob_bottom']:.2f} - ${row['ob_top']:.2f}" - ) - print(f" Bar Range: ${row['low']:.2f} - ${row['high']:.2f}") - print(f" Volume: {row['volume']:,}") - if "ob_strength" in row: - print(f" Strength: {row['ob_strength']:.4f}") - print() - if bearish_obs.height > 10: - print(f" ... 
and {bearish_obs.height - 10} more") - - -def display_fair_value_gaps(df: pl.DataFrame) -> None: - """Display detected Fair Value Gaps with details.""" - # Filter for rows where FVGs are detected - fvg_columns = [col for col in df.columns if col.startswith("fvg_")] - - if not fvg_columns: - print("No FVG columns found in dataframe!") - return - - # Check for bullish FVGs - if "fvg_bullish_start" in df.columns: - bullish_fvgs = df.filter(pl.col("fvg_bullish_start").is_not_null()) - if bullish_fvgs.height > 0: - print(f"\nโฌ†๏ธ BULLISH FAIR VALUE GAPS DETECTED: {bullish_fvgs.height}") - print("-" * 60) - for row in bullish_fvgs.iter_rows(named=True): - print(f" Time: {row['timestamp']}") - print(f" Gap Start: ${row['fvg_bullish_start']:.2f}") - print(f" Gap End: ${row['fvg_bullish_end']:.2f}") - gap_size = row["fvg_bullish_end"] - row["fvg_bullish_start"] - print(f" Gap Size: ${gap_size:.2f}") - print(f" Current Price: ${row['close']:.2f}") - print() - - # Check for bearish FVGs - if "fvg_bearish_start" in df.columns: - bearish_fvgs = df.filter(pl.col("fvg_bearish_start").is_not_null()) - if bearish_fvgs.height > 0: - print(f"\nโฌ‡๏ธ BEARISH FAIR VALUE GAPS DETECTED: {bearish_fvgs.height}") - print("-" * 60) - for row in bearish_fvgs.iter_rows(named=True): - print(f" Time: {row['timestamp']}") - print(f" Gap Start: ${row['fvg_bearish_start']:.2f}") - print(f" Gap End: ${row['fvg_bearish_end']:.2f}") - gap_size = row["fvg_bearish_start"] - row["fvg_bearish_end"] - print(f" Gap Size: ${gap_size:.2f}") - print(f" Current Price: ${row['close']:.2f}") - print() - - -def display_sample_data(df: pl.DataFrame, n: int = 10) -> None: - """Display sample rows from the dataframe.""" - print(f"\n๐Ÿ“Š SAMPLE DATA (last {n} rows):") - print("-" * 60) - - # Select relevant columns for display - display_cols = ["timestamp", "open", "high", "low", "close", "volume"] - - # Add indicator columns if they exist - for col in df.columns: - if col.startswith("ob_") or col.startswith("fvg_"): - display_cols.append(col) - - # Filter to existing columns - display_cols = [col for col in display_cols if col in df.columns] - - # Show the data - sample = df.select(display_cols).tail(n) - print(sample) - - -async def main(): - """Main function to verify Order Block and FVG indicators.""" - print("๐Ÿ” Order Block and Fair Value Gap Indicator Verification") - print("=" * 80) - - try: - # Create ProjectX client - async with ProjectX.from_env() as client: - await client.authenticate() - print(f"โœ… Connected to account: {client.account_info.name}") - - # Get historical data with enough bars for pattern detection - print("\n๐Ÿ“ฅ Loading historical data...") - instrument = "MNQ" # Using E-mini NASDAQ futures - - # Get 10 days of 1-minute bars for more samples - bars_df = await client.get_bars( - instrument, - days=25, - interval=15, # 1-minute bars - unit=2, # 2 = minutes - ) - - if bars_df is None or bars_df.is_empty(): - print("โŒ No data retrieved!") - return - - print(f"โœ… Loaded {bars_df.height} bars of {instrument} data") - - # Display basic dataframe info - display_dataframe_info(bars_df, "Original Data") - - # Apply Order Block indicator - print("\n" + "=" * 80) - print("APPLYING ORDER BLOCK INDICATOR") - print("=" * 80) - - ob_df = bars_df.pipe(ORDERBLOCK, lookback_periods=20) - - # Display Order Block results - display_dataframe_info(ob_df, "Data with Order Blocks") - display_order_blocks(ob_df) - - # Apply Fair Value Gap indicator - print("\n" + "=" * 80) - print("APPLYING FAIR VALUE GAP INDICATOR") - print("=" * 
80) - - fvg_df = bars_df.pipe(FVG) - - # Display FVG results - display_dataframe_info(fvg_df, "Data with Fair Value Gaps") - display_fair_value_gaps(fvg_df) - - # Apply both indicators together - print("\n" + "=" * 80) - print("APPLYING BOTH INDICATORS") - print("=" * 80) - - combined_df = bars_df.pipe(ORDERBLOCK, lookback_periods=20).pipe(FVG) - - # Display combined results - display_dataframe_info(combined_df, "Data with Both Indicators") - - # Show sample of the full dataframe - display_sample_data(combined_df, n=20) - - # Summary statistics - print("\n" + "=" * 80) - print("SUMMARY STATISTICS") - print("=" * 80) - - if "ob_bullish" in combined_df.columns: - bullish_ob_count = combined_df.filter( - pl.col("ob_bullish").is_not_null() - ).height - print(f"๐Ÿ“ˆ Total Bullish Order Blocks: {bullish_ob_count}") - - if "ob_bearish" in combined_df.columns: - bearish_ob_count = combined_df.filter( - pl.col("ob_bearish").is_not_null() - ).height - print(f"๐Ÿ“‰ Total Bearish Order Blocks: {bearish_ob_count}") - - if "fvg_bullish_start" in combined_df.columns: - bullish_fvg_count = combined_df.filter( - pl.col("fvg_bullish_start").is_not_null() - ).height - print(f"โฌ†๏ธ Total Bullish FVGs: {bullish_fvg_count}") - - if "fvg_bearish_start" in combined_df.columns: - bearish_fvg_count = combined_df.filter( - pl.col("fvg_bearish_start").is_not_null() - ).height - print(f"โฌ‡๏ธ Total Bearish FVGs: {bearish_fvg_count}") - - # Export to CSV for further analysis if needed - output_file = f"orderblock_fvg_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" - combined_df.write_csv(output_file) - print(f"\n๐Ÿ’พ Full data exported to: {output_file}") - - print("\nโœ… Verification complete!") - - except Exception as e: - print(f"\nโŒ Error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 79dfb39..279833e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "project-x-py" -version = "3.1.12" +version = "3.1.13" description = "High-performance Python SDK for futures trading with real-time WebSocket data, technical indicators, order management, and market depth analysis" readme = "README.md" license = { text = "MIT" } @@ -47,6 +47,7 @@ dependencies = [ "lz4>=4.4.4", "cachetools>=6.1.0", "plotly>=6.3.0", + "deprecated>=1.2.18", ] [project.optional-dependencies] @@ -296,6 +297,7 @@ dev = [ "types-pytz>=2025.2.0.20250516", "types-pyyaml>=6.0.12.20250516", "psutil>=7.0.0", + "pyjwt>=2.10.1", ] test = [ "pytest>=8.4.1", diff --git a/src/project_x_py/__init__.py b/src/project_x_py/__init__.py index 7dad095..8b4a0ab 100644 --- a/src/project_x_py/__init__.py +++ b/src/project_x_py/__init__.py @@ -95,7 +95,7 @@ from project_x_py.client.base import ProjectXBase -__version__ = "3.1.12" +__version__ = "3.1.13" __author__ = "TexasCoding" # Core client classes - renamed from Async* to standard names diff --git a/src/project_x_py/client/base.py b/src/project_x_py/client/base.py index 07ddc3e..1bf36b8 100644 --- a/src/project_x_py/client/base.py +++ b/src/project_x_py/client/base.py @@ -115,7 +115,7 @@ def __init__( async def __aenter__(self) -> "ProjectXBase": """Async context manager entry.""" - self._client = await self._create_client() + self._client = await self._create_client() # type: ignore[misc] return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: diff --git a/src/project_x_py/client/cache.py b/src/project_x_py/client/cache.py index 
19e35ef..c063024 100644 --- a/src/project_x_py/client/cache.py +++ b/src/project_x_py/client/cache.py @@ -1,15 +1,15 @@ """ -Optimized caching with msgpack serialization and lz4 compression for ProjectX. +Optimized caching with Arrow IPC serialization and lz4 compression for ProjectX. This module provides a high-performance caching layer (`CacheMixin`) designed to significantly reduce latency and memory usage for the ProjectX async client. It -replaces standard pickle/JSON serialization with faster and more efficient alternatives. +uses the robust and efficient Arrow IPC format for serialization. Key Features: -- msgpack: For serialization that is 2-5x faster than pickle. +- Arrow IPC: For fast, robust, and type-preserving serialization of DataFrames. - lz4: For high-speed data compression, achieving up to 70% size reduction on market data. -- cachetools: Implements intelligent LRU (Least Recently Used) and TTL (Time-to-Live) - cache eviction policies for instruments and market data respectively. +- cachetools: Implements intelligent TTL (Time-to-Live) cache eviction policies for + both instruments and market data. - Automatic Compression: Data payloads exceeding a configurable threshold (default 1KB) are automatically compressed. - Performance-Tuned: Optimized for handling Polars DataFrames and other data models @@ -17,33 +17,30 @@ """ import gc +import io import logging -import re -import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -import lz4.frame # type: ignore[import-untyped] -import msgpack # type: ignore[import-untyped] +import lz4.frame # type: ignore import polars as pl -from cachetools import LRUCache, TTLCache # type: ignore[import-untyped] +from cachetools import TTLCache # type: ignore from project_x_py.models import Instrument if TYPE_CHECKING: - from project_x_py.types import ProjectXClientProtocol + pass logger = logging.getLogger(__name__) class CacheMixin: """ - High-performance caching with msgpack serialization and lz4 compression. + High-performance caching with Arrow IPC serialization and lz4 compression. 
This optimized cache provides: - - 2-5x faster serialization with msgpack + - Fast, robust serialization with Arrow IPC format - 70% memory reduction with lz4 compression - - LRU cache for instruments with automatic eviction - - TTL cache for market data with time-based expiry + - TTL cache for instruments and market data with automatic time-based eviction - Compression for large data (> 1KB) - Performance metrics and statistics """ @@ -53,19 +50,16 @@ def __init__(self) -> None: super().__init__() # Cache settings (set early so they can be overridden) - self.cache_ttl = 300 # 5 minutes default - self.last_cache_cleanup = time.time() + self._cache_ttl = 300 # 5 minutes default self.cache_hit_count = 0 - # Internal optimized caches - self._opt_instrument_cache: LRUCache[str, Instrument] = LRUCache(maxsize=1000) - self._opt_instrument_cache_time: dict[str, float] = {} - - # Use cache_ttl for TTLCache + # Internal optimized caches with time-to-live eviction + self._opt_instrument_cache: TTLCache[str, Instrument] = TTLCache( + maxsize=1000, ttl=self._cache_ttl + ) self._opt_market_data_cache: TTLCache[str, bytes] = TTLCache( - maxsize=10000, ttl=self.cache_ttl + maxsize=10000, ttl=self._cache_ttl ) - self._opt_market_data_cache_time: dict[str, float] = {} # Compression settings (configurable) self.compression_threshold = getattr(self, "config", {}).get( @@ -75,39 +69,45 @@ def __init__(self) -> None: "compression_level", 3 ) # lz4 compression level (0-16) + @property + def cache_ttl(self) -> float: + """Get cache TTL value.""" + return self._cache_ttl + + @cache_ttl.setter + def cache_ttl(self, value: float) -> None: + """ + Set cache TTL and recreate caches with new TTL. + + Args: + value: New TTL value in seconds + """ + self._cache_ttl = value + # Recreate caches with new TTL + self._opt_instrument_cache = TTLCache(maxsize=1000, ttl=value) + self._opt_market_data_cache = TTLCache(maxsize=10000, ttl=value) + def _serialize_dataframe(self, df: pl.DataFrame) -> bytes: """ - Serialize Polars DataFrame efficiently using msgpack. + Serialize Polars DataFrame efficiently using the Arrow IPC format. - Optimized for DataFrames with numeric data. + This method is significantly more robust and performant than the previous + msgpack-based serialization. It preserves all data types, including + timezones, without manual conversion. + + Args: + df: The Polars DataFrame to serialize. + + Returns: + The serialized DataFrame as bytes, potentially compressed. 
""" if df.is_empty(): return b"" - # Convert to dictionary format for msgpack - columns_data = {} - for col in df.columns: - col_data = df[col] - # Convert datetime columns to ISO strings for msgpack serialization - if col_data.dtype in [pl.Datetime, pl.Date]: - columns_data[col] = col_data.dt.to_string( - "%Y-%m-%d %H:%M:%S%.f" - ).to_list() - else: - columns_data[col] = col_data.to_list() - - data = { - "schema": {name: str(dtype) for name, dtype in df.schema.items()}, - "columns": columns_data, - "shape": df.shape, - } - - # Use msgpack for serialization - packed = msgpack.packb( - data, - use_bin_type=True, - default=str, # Fallback for unknown types - ) + # Use Polars' built-in IPC serialization + buffer = io.BytesIO() + df.write_ipc(buffer) + packed = buffer.getvalue() # Compress if data is large if len(packed) > self.compression_threshold: @@ -117,15 +117,22 @@ def _serialize_dataframe(self, df: pl.DataFrame) -> bytes: content_checksum=False, # Skip checksum for speed ) # Add header to indicate compression - result: bytes = b"LZ4" + compressed - return result + return b"LZ4" + compressed - result = b"RAW" + packed - return result + return b"RAW" + packed def _deserialize_dataframe(self, data: bytes) -> pl.DataFrame | None: """ - Deserialize DataFrame from cached bytes. + Deserialize DataFrame from cached bytes using the Arrow IPC format. + + This method correctly handles data serialized by `_serialize_dataframe`, + including decompression and type reconstruction. + + Args: + data: The byte string from the cache. + + Returns: + A deserialized Polars DataFrame, or None if deserialization fails. """ if not data: return None @@ -138,206 +145,93 @@ def _deserialize_dataframe(self, data: bytes) -> pl.DataFrame | None: if header == b"LZ4": try: payload = lz4.frame.decompress(payload) - except Exception: - # Fall back to raw data if decompression fails - payload = data[3:] - elif header == b"RAW": - pass # Already uncompressed - else: - # Legacy uncompressed data - payload = data + except Exception as e: + logger.warning(f"LZ4 decompression failed for cached data: {e}") + return None # Data is corrupt, cannot proceed + elif header != b"RAW": + logger.warning( + f"Unknown cache format header '{header.decode(errors='ignore')}'. " + "Cache may be from an old version. Clearing and refetching is recommended." 
+ ) + return None try: - # Deserialize with msgpack - unpacked = msgpack.unpackb(payload, raw=False) - if not unpacked or "columns" not in unpacked: - return None - - # Reconstruct DataFrame with proper schema - df = pl.DataFrame(unpacked["columns"]) - - # Restore datetime columns based on stored schema - if "schema" in unpacked: - for col_name, dtype_str in unpacked["schema"].items(): - if "datetime" in dtype_str.lower() and col_name in df.columns: - # Parse timezone from dtype string (e.g., "Datetime(time_unit='us', time_zone='UTC')") - time_zone = None - if "time_zone=" in dtype_str: - # Extract timezone - tz_match = re.search(r"time_zone='([^']+)'", dtype_str) - if tz_match: - time_zone = tz_match.group(1) - - # Convert string column to datetime - if df[col_name].dtype == pl.Utf8: - df = df.with_columns( - pl.col(col_name) - .str.strptime( - pl.Datetime("us", time_zone), - "%Y-%m-%d %H:%M:%S%.f", - strict=False, - ) - .alias(col_name) - ) - - return df + # Use Polars' built-in IPC deserialization + buffer = io.BytesIO(payload) + return pl.read_ipc(buffer) except Exception as e: - logger.debug(f"Failed to deserialize DataFrame: {e}") + logger.debug(f"Failed to deserialize DataFrame from IPC format: {e}") return None def get_cached_instrument(self, symbol: str) -> Instrument | None: """ Get cached instrument data if available and not expired. - Compatible with CacheMixin interface. - Args: symbol: Trading symbol Returns: - Cached instrument or None if not found/expired + Cached instrument or None if not found or expired. """ cache_key = symbol.upper() - - # Check TTL expiry for compatibility - if cache_key in self._opt_instrument_cache_time: - cache_age = time.time() - self._opt_instrument_cache_time[cache_key] - if cache_age > self.cache_ttl: - # Expired - remove from cache - if cache_key in self._opt_instrument_cache: - del self._opt_instrument_cache[cache_key] - del self._opt_instrument_cache_time[cache_key] - return None - - # Try optimized cache - if cache_key in self._opt_instrument_cache: + instrument = cast(Instrument | None, self._opt_instrument_cache.get(cache_key)) + if instrument: self.cache_hit_count += 1 - instrument: Instrument = self._opt_instrument_cache[cache_key] return instrument - return None def cache_instrument(self, symbol: str, instrument: Instrument) -> None: """ - Cache instrument data. - - Compatible with CacheMixin interface. + Cache instrument data with a time-to-live. Args: symbol: Trading symbol instrument: Instrument object to cache """ cache_key = symbol.upper() - - # Store in optimized cache self._opt_instrument_cache[cache_key] = instrument - self._opt_instrument_cache_time[cache_key] = time.time() def get_cached_market_data(self, cache_key: str) -> pl.DataFrame | None: """ Get cached market data if available and not expired. - Compatible with CacheMixin interface. - Args: cache_key: Unique key for the cached data Returns: - Cached DataFrame or None if not found/expired + Cached DataFrame or None if not found or expired. 
""" - # Check TTL expiry for compatibility with dynamic cache_ttl - if cache_key in self._opt_market_data_cache_time: - cache_age = time.time() - self._opt_market_data_cache_time[cache_key] - if cache_age > self.cache_ttl: - # Expired - remove from cache - if cache_key in self._opt_market_data_cache: - del self._opt_market_data_cache[cache_key] - del self._opt_market_data_cache_time[cache_key] - return None - - # Try optimized cache first - if cache_key in self._opt_market_data_cache: - serialized = self._opt_market_data_cache[cache_key] + serialized = self._opt_market_data_cache.get(cache_key) + if serialized: df = self._deserialize_dataframe(serialized) if df is not None: self.cache_hit_count += 1 return df - return None def cache_market_data(self, cache_key: str, data: pl.DataFrame) -> None: """ - Cache market data. - - Compatible with CacheMixin interface. + Cache market data with a time-to-live. Args: cache_key: Unique key for the data data: DataFrame to cache """ - # Serialize and store in optimized cache serialized = self._serialize_dataframe(data) self._opt_market_data_cache[cache_key] = serialized - self._opt_market_data_cache_time[cache_key] = time.time() - - async def _cleanup_cache(self: "ProjectXClientProtocol") -> None: - """ - Clean up expired cache entries to manage memory usage. - - This method is called periodically to remove expired entries. - LRUCache and TTLCache handle their own eviction, but we still - track timestamps for dynamic TTL changes. - """ - current_time = time.time() - - # Clean up timestamp tracking for expired entries - expired_instruments = [ - symbol - for symbol, cache_time in self._opt_instrument_cache_time.items() - if current_time - cache_time > self.cache_ttl - ] - for symbol in expired_instruments: - if symbol in self._opt_instrument_cache: - del self._opt_instrument_cache[symbol] - del self._opt_instrument_cache_time[symbol] - - expired_data = [ - key - for key, cache_time in self._opt_market_data_cache_time.items() - if current_time - cache_time > self.cache_ttl - ] - for key in expired_data: - if key in self._opt_market_data_cache: - del self._opt_market_data_cache[key] - del self._opt_market_data_cache_time[key] - - self.last_cache_cleanup = current_time - - # Force garbage collection if caches were large - if len(expired_instruments) > 10 or len(expired_data) > 10: - gc.collect() def clear_all_caches(self) -> None: """ Clear all cached data. - - Compatible with CacheMixin interface. """ - # Clear optimized caches self._opt_instrument_cache.clear() - self._opt_instrument_cache_time.clear() self._opt_market_data_cache.clear() - self._opt_market_data_cache_time.clear() - - # Reset stats self.cache_hit_count = 0 gc.collect() def get_cache_stats(self) -> dict[str, Any]: """ Get comprehensive cache statistics. - - Extended version with optimization metrics. 
""" total_hits = self.cache_hit_count @@ -345,13 +239,9 @@ def get_cache_stats(self) -> dict[str, Any]: "cache_hits": total_hits, "instrument_cache_size": len(self._opt_instrument_cache), "market_data_cache_size": len(self._opt_market_data_cache), - "instrument_cache_max": getattr( - self._opt_instrument_cache, "maxsize", 1000 - ), - "market_data_cache_max": getattr( - self._opt_market_data_cache, "maxsize", 10000 - ), + "instrument_cache_max": self._opt_instrument_cache.maxsize, + "market_data_cache_max": self._opt_market_data_cache.maxsize, "compression_enabled": True, - "serialization": "msgpack", + "serialization": "arrow-ipc", "compression": "lz4", } diff --git a/src/project_x_py/client/http.py b/src/project_x_py/client/http.py index 79141df..042aa2d 100644 --- a/src/project_x_py/client/http.py +++ b/src/project_x_py/client/http.py @@ -236,9 +236,6 @@ async def _make_request( message = format_error_message( ErrorMessages.API_RATE_LIMITED, retry_after=retry_after ) - # Ensure test expectation for phrase "rate limited" is satisfied - if "rate limited" not in message.lower(): - message = f"{message} (rate limited)" raise ProjectXRateLimitError(message) # Handle successful responses @@ -311,10 +308,8 @@ async def get_health_status( Get client statistics and performance metrics. Returns: - Dict containing: - - client_stats: Client-side statistics including cache performance - - authenticated: Whether the client is authenticated - - account: Current account name if authenticated + A dictionary containing client-side statistics including API calls + and cache performance. Example: >>> # V3: Get comprehensive performance metrics @@ -322,32 +317,23 @@ async def get_health_status( >>> print(f"API Calls: {status['api_calls']}") >>> print(f"Cache Hits: {status['cache_hits']}") >>> print(f"Cache Hit Ratio: {status['cache_hit_ratio']:.2%}") - >>> print(f"Success Rate: {status['success_rate']:.2%}") + >>> print(f"Total Requests: {status['total_requests']}") >>> print(f"Active Connections: {status['active_connections']}") """ # Calculate client statistics - total_cache_requests = self.cache_hit_count + self.api_call_count - cache_hit_rate = ( - self.cache_hit_count / total_cache_requests - if total_cache_requests > 0 - else 0 + total_requests = self.cache_hit_count + self.api_call_count + cache_hit_ratio = ( + self.cache_hit_count / total_requests if total_requests > 0 else 0 ) - - # Calculate additional metrics for PerformanceStatsResponse cache_misses = self.api_call_count # API calls are essentially cache misses - success_rate = 1.0 # Simplified - would need actual failure tracking - uptime_seconds = 0 # Would need to track session start time return { "api_calls": self.api_call_count, "cache_hits": self.cache_hit_count, "cache_misses": cache_misses, - "cache_hit_ratio": cache_hit_rate, - "avg_response_time_ms": 0.0, # Would need response time tracking - "total_requests": total_cache_requests, - "failed_requests": 0, # Would need failure tracking - "success_rate": success_rate, - "active_connections": 1 if self._authenticated else 0, - "memory_usage_mb": 0.0, # Would need memory monitoring - "uptime_seconds": uptime_seconds, + "cache_hit_ratio": cache_hit_ratio, + "total_requests": total_requests, + "active_connections": 1 + if self._client and not self._client.is_closed + else 0, } diff --git a/src/project_x_py/client/market_data.py b/src/project_x_py/client/market_data.py index 5f167a2..3c62caf 100644 --- a/src/project_x_py/client/market_data.py +++ b/src/project_x_py/client/market_data.py @@ 
-56,7 +56,6 @@ async def main(): import datetime import re -import time from typing import TYPE_CHECKING, Any import polars as pl @@ -130,18 +129,29 @@ async def get_instrument( # If so, extract the base symbol for searching search_symbol = symbol is_contract_id = False - if symbol.startswith("CON.") and symbol.count(".") >= 3: - is_contract_id = True - # Extract base symbol from contract ID - # CON.F.US.MNQ.U25 -> MNQ - parts = symbol.split(".") - if len(parts) >= 4: - search_symbol = parts[3] - # Remove any month/year suffix (e.g., U25 -> base symbol) - futures_pattern = re.compile(r"^(.+?)([FGHJKMNQUVXZ]\d{1,2})$") - match = futures_pattern.match(search_symbol) - if match: - search_symbol = match.group(1) + if symbol.startswith("CON."): + # Regex to capture the symbol part of a contract ID, e.g., "MNQ.U25" from "CON.F.US.MNQ.U25" + # This is more robust than splitting by '.' and relying on indices. + contract_pattern = re.compile( + r"^CON\.[A-Z]\.[A-Z]{2}\.(?P<symbol_details>.+)$" + ) + match = contract_pattern.match(symbol) + if match: + is_contract_id = True + symbol_details = match.group("symbol_details") + # The symbol can be in parts (e.g., "MNQ.U25") or joined (e.g., "MNQU25") + # We only want the base symbol, e.g., "MNQ" + base_symbol_part = symbol_details.split(".")[0] + + # Remove any futures month/year suffix from the base symbol part + futures_pattern = re.compile( + r"^(?P<base>.+?)([FGHJKMNQUVXZ]\d{1,2})$" + ) + futures_match = futures_pattern.match(base_symbol_part) + if futures_match: + search_symbol = futures_match.group("base") + else: + search_symbol = base_symbol_part # Search for instrument payload = {"searchText": search_symbol, "live": live} @@ -187,10 +197,6 @@ async def get_instrument( self.cache_instrument(symbol, instrument) logger.debug(LogMessages.CACHE_UPDATE, extra={"symbol": symbol}) - # Periodic cache cleanup - if time.time() - self.last_cache_cleanup > 3600: # Every hour - await self._cleanup_cache() - return instrument def _select_best_contract( @@ -233,7 +239,7 @@ # First try exact match for inst in instruments: - if inst.get("symbol", "").upper() == search_upper: + if inst.get("name", "").upper() == search_upper: return inst # For futures, try to find the front month @@ -242,8 +248,8 @@ base_symbols: dict[str, list[dict[str, Any]]] = {} for inst in instruments: - symbol = inst.get("symbol", "").upper() - match = futures_pattern.match(symbol) + name = inst.get("name", "").upper() + match = futures_pattern.match(name) if match: base = match.group(1) if base not in base_symbols: @@ -258,9 +264,9 @@ break if matching_base and base_symbols[matching_base]: - # Sort by symbol to get front month (alphabetical = chronological for futures) + # Sort by name to get front month (alphabetical = chronological for futures) sorted_contracts = sorted( - base_symbols[matching_base], key=lambda x: x.get("symbol", "") + base_symbols[matching_base], key=lambda x: x.get("name", "") ) return sorted_contracts[0] @@ -524,8 +530,4 @@ async def get_bars( # Cache the result self.cache_market_data(cache_key, data) - # Cleanup cache periodically - if time.time() - self.last_cache_cleanup > 3600: - await self._cleanup_cache() - return data diff --git a/src/project_x_py/client/trading.py b/src/project_x_py/client/trading.py index 2c446f2..a57c786 100644 --- a/src/project_x_py/client/trading.py +++ b/src/project_x_py/client/trading.py @@ -66,10 +66,12 @@ async def main(): import datetime import logging +import warnings from
datetime import timedelta from typing import TYPE_CHECKING import pytz +from deprecated import deprecated # type: ignore from project_x_py.exceptions import ProjectXError from project_x_py.models import Position, Trade @@ -83,69 +85,50 @@ async def main(): class TradingMixin: """Mixin class providing trading functionality.""" + @deprecated( # type: ignore[misc] + "Use search_open_positions() instead. This method will be removed in v4.0.0." + ) async def get_positions(self: "ProjectXClientProtocol") -> list[Position]: """ - Get all open positions for the authenticated account. + DEPRECATED: Get all open positions for the authenticated account. - This method retrieves all current open positions for the currently selected - trading account. It provides a snapshot of the portfolio including position - size, entry price, unrealized P&L, and other position details. + This method is deprecated and will be removed in a future version. + Please use `search_open_positions()` instead, which provides the same + functionality with a more consistent API endpoint. - The method automatically ensures authentication before making the API call - and handles the conversion of raw position data into strongly-typed Position - objects for easier analysis and manipulation. + Args: + self: The client instance. Returns: - list[Position]: List of Position objects representing current holdings, - each containing: - - symbol: Instrument symbol - - quantity: Position size (positive for long, negative for short) - - price: Average entry price - - unrealized_pnl: Current unrealized profit/loss - - contract_id: Unique contract identifier - - position_id: Unique position identifier - - timestamp: Position open time - - Raises: - ProjectXError: If no account is selected or API call fails - ProjectXAuthenticationError: If authentication is required - - Example: - >>> # V3: Get detailed position information - >>> positions = await client.get_positions() - >>> for pos in positions: - >>> print(f"Contract: {pos.contractId}") - >>> print(f" Net Position: {pos.netPos}") - >>> print(f" Buy Avg Price: ${pos.buyAvgPrice:.2f}") - >>> print(f" Sell Avg Price: ${pos.sellAvgPrice:.2f}") - >>> print(f" Unrealized P&L: ${pos.unrealizedPnl:,.2f}") - >>> print(f" Realized P&L: ${pos.realizedPnl:,.2f}") + A list of Position objects representing current holdings. """ - await self._ensure_authenticated() - - if not self.account_info: - raise ProjectXError("No account selected") - - response = await self._make_request( - "GET", f"/accounts/{self.account_info.id}/positions" + warnings.warn( + "get_positions() is deprecated, use search_open_positions() instead. " + "This method will be removed in v4.0.0.", + DeprecationWarning, + stacklevel=2, ) - - if not response or not isinstance(response, list): - return [] - - return [Position(**pos) for pos in response] + return await self.search_open_positions() async def search_open_positions( self: "ProjectXClientProtocol", account_id: int | None = None ) -> list[Position]: """ - Search for open positions across accounts. + Search for open positions for the currently authenticated account. + + This is the recommended method for retrieving all current open positions. + It provides a snapshot of the portfolio including position size, entry price, + unrealized P&L, and other key details. Args: - account_id: Optional account ID to filter positions + account_id: Optional account ID to filter positions. If not provided, + the currently authenticated account's ID will be used. 
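For callers migrating off the deprecated method, a minimal sketch (the `client` object is assumed to be an authenticated ProjectX client); escalating `DeprecationWarning` to an error in tests surfaces any remaining `get_positions()` call sites:

```python
import warnings

async def audit_positions(client) -> None:
    # Fail loudly in test runs if any code path still calls get_positions().
    warnings.simplefilter("error", DeprecationWarning)
    positions = await client.search_open_positions()  # drop-in replacement
    for pos in positions:
        print(pos.contractId, pos.size)
```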
Returns: - List of Position objects + List of Position objects representing current holdings. + + Raises: + ProjectXError: If no account is selected or the API call fails. Example: >>> # V3: Search open positions with P&L calculation @@ -157,12 +140,6 @@ async def search_open_positions( >>> print(f"Total Unrealized P&L: ${total_unrealized:,.2f}") >>> print(f"Total Realized P&L: ${total_realized:,.2f}") >>> print(f"Total P&L: ${total_unrealized + total_realized:,.2f}") - >>> # Group by contract - >>> by_contract = {} - >>> for pos in positions: - >>> if pos.contractId not in by_contract: - >>> by_contract[pos.contractId] = [] - >>> by_contract[pos.contractId].append(pos) """ await self._ensure_authenticated() @@ -178,10 +155,21 @@ async def search_open_positions( "POST", "/Position/searchOpen", data=payload ) - if not response or not response.get("success", False): + # Handle both list response (new API) and dict response (legacy) + if response is None: + return [] + + # If response is a list, use it directly + if isinstance(response, list): + positions_data = response + # If response is a dict with success/positions structure + elif isinstance(response, dict): + if not response.get("success", False): + return [] + positions_data = response.get("positions", []) + else: return [] - positions_data = response.get("positions", []) return [Position(**pos) for pos in positions_data] async def search_trades( diff --git a/src/project_x_py/data/mmap_storage.py b/src/project_x_py/data/mmap_storage.py index 3d93398..1a7ff21 100644 --- a/src/project_x_py/data/mmap_storage.py +++ b/src/project_x_py/data/mmap_storage.py @@ -5,8 +5,11 @@ allowing efficient access to data without loading everything into RAM. """ +import logging import mmap import pickle +import tempfile +import threading from io import BufferedRandom, BufferedReader from pathlib import Path from typing import Any @@ -14,6 +17,8 @@ import numpy as np import polars as pl +logger = logging.getLogger(__name__) + class MemoryMappedStorage: """ @@ -41,13 +46,8 @@ def __init__(self, filename: str | Path, mode: str = "r+b"): self.mmap: mmap.mmap | None = None self._metadata: dict[str, Any] = {} self._file_size = 1024 * 1024 * 10 # Start with 10MB - - # Create file if it doesn't exist (unless read-only) - if not self.filename.exists() and "+" in mode: - self.filename.parent.mkdir(parents=True, exist_ok=True) - # Pre-allocate file with initial size - with open(self.filename, "wb") as f: - f.write(b"\x00" * self._file_size) + self._data_file_size = 0 + self._lock = threading.RLock() def __enter__(self) -> "MemoryMappedStorage": """Context manager entry.""" @@ -60,13 +60,26 @@ def __exit__(self, *args: Any) -> None: def open(self) -> None: """Open the memory-mapped file.""" - if self.fp is None: - self.fp = open(self.filename, self.mode) # type: ignore # noqa: SIM115 + with self._lock: + if self.fp is not None: + return + + # Create file if it doesn't exist (unless read-only) + if not self.filename.exists() and ("+" in self.mode or "w" in self.mode): + self.filename.parent.mkdir(parents=True, exist_ok=True) + # Pre-allocate file with initial size + with open(self.filename, "wb") as f: + f.write(b"\x00" * self._file_size) + + self.fp = open(self.filename, self.mode) # type: ignore # noqa: SIM115 + + if self.fp is None: + raise ValueError("File pointer is None") - if self.fp is not None: # Get file size self.fp.seek(0, 2) # Seek to end size = self.fp.tell() + self._data_file_size = size if size == 0 and ("+" in self.mode or "w" in self.mode): # 
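The dual-format handling above distills to a small normalizer; a standalone sketch for reference:

```python
from typing import Any

def normalize_positions_response(response: Any) -> list[dict[str, Any]]:
    """Accept either a bare list (new API) or the legacy
    {"success": ..., "positions": [...]} envelope."""
    if response is None:
        return []
    if isinstance(response, list):
        return response
    if isinstance(response, dict):
        if not response.get("success", False):
            return []
        return response.get("positions", [])
    return []

assert normalize_positions_response(None) == []
assert normalize_positions_response([{"id": 1}]) == [{"id": 1}]
assert normalize_positions_response({"success": True, "positions": [{"id": 1}]}) == [{"id": 1}]
```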
Initialize empty file with default size @@ -74,43 +87,47 @@ def open(self) -> None: self.fp.flush() self.fp.seek(0) size = self._file_size + self._data_file_size = size if size > 0: # Use ACCESS_READ for read-only mode - if "r" in self.mode and "+" not in self.mode: - self.mmap = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ) - else: - self.mmap = mmap.mmap(self.fp.fileno(), 0) + access = ( + mmap.ACCESS_READ + if "r" in self.mode and "+" not in self.mode + else mmap.ACCESS_DEFAULT + ) + if self.fp: + self.mmap = mmap.mmap(self.fp.fileno(), 0, access=access) def close(self) -> None: """Close the memory-mapped file.""" - if self.mmap: - self.mmap.close() - self.mmap = None - if self.fp: - self.fp.close() - self.fp = None + with self._lock: + if self.mmap: + self.mmap.close() + self.mmap = None + if self.fp: + self.fp.close() + self.fp = None def _resize_file(self, new_size: int) -> None: """Resize the file and recreate mmap (for macOS compatibility).""" - # Close existing mmap + # This method should be called within a lock if self.mmap: self.mmap.close() if self.fp is None: raise ValueError("File pointer is None") - # Resize the file - self.fp.seek(0, 2) # Go to end - current_size = self.fp.tell() + self.fp.truncate(new_size) + self.fp.flush() - if new_size > current_size: - # Extend the file - self.fp.write(b"\\x00" * (new_size - current_size)) - self.fp.flush() - - # Recreate mmap with new size - self.mmap = mmap.mmap(self.fp.fileno(), 0) + access = ( + mmap.ACCESS_READ + if "r" in self.mode and "+" not in self.mode + else mmap.ACCESS_DEFAULT + ) + self.mmap = mmap.mmap(self.fp.fileno(), 0, access=access) + self._data_file_size = new_size def write_array(self, data: np.ndarray, offset: int = 0) -> int: """ @@ -123,36 +140,36 @@ def write_array(self, data: np.ndarray, offset: int = 0) -> int: Returns: Number of bytes written """ - if not self.mmap: - self.open() - - # Serialize array metadata - metadata = {"dtype": str(data.dtype), "shape": data.shape, "offset": offset} - - # Convert to bytes - data_bytes = data.tobytes() - metadata_bytes = pickle.dumps(metadata) - - # Write metadata size (4 bytes), metadata, then data - size_bytes = len(metadata_bytes).to_bytes(4, "little") - - # Check if we need more space - total_size = offset + 4 + len(metadata_bytes) + len(data_bytes) - if self.mmap and total_size > len(self.mmap): - # On macOS, we can't resize mmap, so we need to recreate it - self._resize_file(total_size) - elif not self.mmap: - self.open() - if self.mmap and total_size > len(self.mmap): + with self._lock: + if not self.mmap: + self.open() + + if not self.mmap: + raise OSError("Memory map not available") + + # Serialize array metadata + metadata = {"dtype": str(data.dtype), "shape": data.shape, "offset": offset} + + # Convert to bytes + data_bytes = data.tobytes() + metadata_bytes = pickle.dumps(metadata) + + # Write metadata size (4 bytes), metadata, then data + size_bytes = len(metadata_bytes).to_bytes(4, "little") + + # Check if we need more space + total_size = offset + 4 + len(metadata_bytes) + len(data_bytes) + if total_size > self._data_file_size: + # On macOS, we can't resize mmap, so we need to recreate it self._resize_file(total_size) - # Write to mmap - if self.mmap: + # Write to mmap self.mmap[offset : offset + 4] = size_bytes self.mmap[offset + 4 : offset + 4 + len(metadata_bytes)] = metadata_bytes self.mmap[offset + 4 + len(metadata_bytes) : total_size] = data_bytes self.mmap.flush() - return total_size - offset + + return total_size - offset def 
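For orientation, `write_array`/`read_array` use a simple length-prefixed record layout. This sketch packs and unpacks the same format in memory (the real metadata also records the write offset, omitted here):

```python
import pickle
import numpy as np

def pack_record(arr: np.ndarray) -> bytes:
    # 4-byte little-endian metadata length, pickled metadata, raw array bytes.
    meta = pickle.dumps({"dtype": str(arr.dtype), "shape": arr.shape})
    return len(meta).to_bytes(4, "little") + meta + arr.tobytes()

def unpack_record(buf: bytes) -> np.ndarray:
    meta_len = int.from_bytes(buf[:4], "little")
    meta = pickle.loads(buf[4 : 4 + meta_len])
    data = buf[4 + meta_len :]
    return np.frombuffer(data, dtype=np.dtype(meta["dtype"])).reshape(meta["shape"]).copy()

arr = np.arange(6, dtype=np.float64).reshape(2, 3)
assert (unpack_record(pack_record(arr)) == arr).all()
```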
read_array(self, offset: int = 0) -> np.ndarray | None: """ @@ -164,37 +181,65 @@ def read_array(self, offset: int = 0) -> np.ndarray | None: Returns: NumPy array or None if not found """ - if not self.mmap: - self.open() + with self._lock: + if not self.mmap: + self.open() - if not self.mmap: - return None + if not self.mmap: + return None - try: - # Read metadata size - size_bytes = self.mmap[offset : offset + 4] - metadata_size = int.from_bytes(size_bytes, "little") + try: + # Read metadata size + size_bytes = self.mmap[offset : offset + 4] + metadata_size = int.from_bytes(size_bytes, "little") - # Read metadata - metadata_bytes = self.mmap[offset + 4 : offset + 4 + metadata_size] - metadata = pickle.loads(metadata_bytes) + # Read metadata + metadata_bytes = self.mmap[offset + 4 : offset + 4 + metadata_size] + metadata = pickle.loads(metadata_bytes) - # Calculate data size - dtype = np.dtype(metadata["dtype"]) - shape = metadata["shape"] - data_size = dtype.itemsize * np.prod(shape) + # Calculate data size + dtype = np.dtype(metadata["dtype"]) + shape = metadata["shape"] + data_size = dtype.itemsize * np.prod(shape) - # Read data - data_start = offset + 4 + metadata_size - data_bytes = self.mmap[data_start : data_start + data_size] + # Read data + data_start = offset + 4 + metadata_size + data_bytes = self.mmap[data_start : data_start + data_size] - # Convert to array - array = np.frombuffer(data_bytes, dtype=dtype).reshape(shape) - return array.copy() # Return copy to avoid mmap issues + # Convert to array + array = np.frombuffer(data_bytes, dtype=dtype).reshape(shape) + return array.copy() # Return copy to avoid mmap issues - except Exception as e: - print(f"Error reading array: {e}") - return None + except Exception: + logger.exception("Error reading array at offset %d", offset) + return None + + def _load_metadata(self) -> None: + # should be called within a lock + if not self._metadata: + metadata_file = self.filename.with_suffix(".meta") + if metadata_file.exists(): + try: + with open(metadata_file, "rb") as f: + self._metadata = pickle.load(f) + except (pickle.UnpicklingError, EOFError): + logger.exception( + "Could not load metadata from %s, file might be corrupt.", + metadata_file, + ) + self._metadata = {} + + def _save_metadata(self) -> None: + # should be called within a lock + metadata_file = self.filename.with_suffix(".meta") + # Safe save: write to temp file then rename + with tempfile.NamedTemporaryFile( + "wb", delete=False, dir=metadata_file.parent + ) as tmp_f: + pickle.dump(self._metadata, tmp_f) + tmp_path = Path(tmp_f.name) + + tmp_path.rename(metadata_file) def write_dataframe(self, df: pl.DataFrame, key: str = "default") -> bool: """ @@ -207,48 +252,48 @@ def write_dataframe(self, df: pl.DataFrame, key: str = "default") -> bool: Returns: Success status """ - try: - # Load existing metadata if present - metadata_file = self.filename.with_suffix(".meta") - if metadata_file.exists(): - with open(metadata_file, "rb") as f: - self._metadata = pickle.load(f) - - # Calculate starting offset (after existing data) - offset = 0 - for existing_key, existing_data in self._metadata.items(): - if existing_key != key and "columns" in existing_data: - for col_info in existing_data["columns"].values(): - offset = max(offset, col_info["offset"] + col_info["size"]) - - # Convert DataFrame to dict format - data: dict[str, Any] = { - "schema": {name: str(dtype) for name, dtype in df.schema.items()}, - "columns": {}, - "shape": df.shape, - "key": key, - } - - # Store each column as 
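The `_save_metadata` helper relies on the classic write-then-rename idiom; a generic sketch of the same pattern (atomic on POSIX filesystems, so a crash mid-write leaves the previous file intact):

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any

def atomic_pickle_dump(obj: Any, target: Path) -> None:
    # Write to a temp file in the same directory, then rename over the target.
    with tempfile.NamedTemporaryFile("wb", delete=False, dir=target.parent) as tmp:
        pickle.dump(obj, tmp)
        tmp_path = Path(tmp.name)
    tmp_path.rename(target)
```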
NumPy array - for col_name in df.columns: - col_data = df[col_name].to_numpy() - bytes_written = self.write_array(col_data, offset) - data["columns"][col_name] = {"offset": offset, "size": bytes_written} - offset += bytes_written - - # Store metadata - self._metadata[key] = data - - # Write metadata to a separate file - metadata_file = self.filename.with_suffix(".meta") - with open(metadata_file, "wb") as f: - pickle.dump(self._metadata, f) - - return True - - except Exception as e: - print(f"Error writing DataFrame: {e}") - return False + with self._lock: + try: + if not self.mmap: + self.open() + + # Load existing metadata if present + self._load_metadata() + + # Calculate starting offset (after existing data) + offset = self._data_file_size + + # Convert DataFrame to dict format + data: dict[str, Any] = { + "schema": {name: str(dtype) for name, dtype in df.schema.items()}, + "columns": {}, + "shape": df.shape, + "key": key, + } + + # Store each column as NumPy array + col_offset = offset + for col_name in df.columns: + col_data = df[col_name].to_numpy() + bytes_written = self.write_array(col_data, col_offset) + data["columns"][col_name] = { + "offset": col_offset, + "size": bytes_written, + } + col_offset += bytes_written + + # Store metadata + self._metadata[key] = data + self._save_metadata() + + # Update data file size tracker + self._data_file_size = col_offset + + return True + + except Exception: + logger.exception("Error writing DataFrame with key '%s'", key) + return False def read_dataframe(self, key: str = "default") -> pl.DataFrame | None: """ @@ -260,32 +305,29 @@ def read_dataframe(self, key: str = "default") -> pl.DataFrame | None: Returns: Polars DataFrame or None if not found """ - try: - # Load metadata if not already loaded - if not self._metadata: - metadata_file = self.filename.with_suffix(".meta") - if metadata_file.exists(): - with open(metadata_file, "rb") as f: - self._metadata = pickle.load(f) + with self._lock: + try: + # Load metadata if not already loaded + self._load_metadata() - if key not in self._metadata: - return None + if key not in self._metadata: + return None - metadata = self._metadata[key] + metadata = self._metadata[key] - # Read each column - columns = {} - for col_name, col_info in metadata["columns"].items(): - array = self.read_array(col_info["offset"]) - if array is not None: - columns[col_name] = array + # Read each column + columns = {} + for col_name, col_info in metadata["columns"].items(): + array = self.read_array(col_info["offset"]) + if array is not None: + columns[col_name] = array - # Reconstruct DataFrame - return pl.DataFrame(columns) + # Reconstruct DataFrame + return pl.DataFrame(columns) - except Exception as e: - print(f"Error reading DataFrame: {e}") - return None + except Exception: + logger.exception("Error reading DataFrame with key '%s'", key) + return None def get_info(self) -> dict[str, Any]: """ @@ -294,17 +336,21 @@ def get_info(self) -> dict[str, Any]: Returns: Dictionary with storage information """ - info = { - "filename": str(self.filename), - "exists": self.filename.exists(), - "size_mb": 0, - "keys": list(self._metadata.keys()) if self._metadata else [], - } + with self._lock: + # Load metadata if not already loaded + self._load_metadata() - if self.filename.exists(): - info["size_mb"] = self.filename.stat().st_size / (1024 * 1024) + info = { + "filename": str(self.filename), + "exists": self.filename.exists(), + "size_mb": 0, + "keys": list(self._metadata.keys()), + } + + if self.filename.exists(): + 
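End to end, the DataFrame path now behaves like this sketch (filename and key are illustrative; `MemoryMappedStorage` is the class defined above):

```python
import polars as pl

df = pl.DataFrame({"price": [17890.0, 17891.5], "volume": [3.0, 5.0]})
with MemoryMappedStorage("ticks.bin") as store:
    assert store.write_dataframe(df, key="MNQ")
    restored = store.read_dataframe("MNQ")  # round-trips via the .meta sidecar
```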
info["size_mb"] = self.filename.stat().st_size / (1024 * 1024) - return info + return info class TimeSeriesStorage(MemoryMappedStorage): @@ -327,9 +373,15 @@ def __init__( """ super().__init__(filename, "r+b") self.columns = columns - self.dtype = dtype + self.dtype = np.dtype(dtype) + self.row_size = (len(self.columns) + 1) * self.dtype.itemsize self.current_size = 0 - self.chunk_size = 10000 # Records per chunk + + # Determine current size from existing file + if self.filename.exists(): + self.open() # open() is idempotent and thread-safe + if self._data_file_size > 0 and self.row_size > 0: + self.current_size = self._data_file_size // self.row_size def append_data(self, timestamp: float, values: dict[str, float]) -> bool: """ @@ -342,42 +394,53 @@ def append_data(self, timestamp: float, values: dict[str, float]) -> bool: Returns: Success status """ - try: - if not self.mmap: - self.open() + with self._lock: + try: + if not self.mmap: + self.open() - # Create row array - row: np.ndarray = np.zeros(len(self.columns) + 1, dtype=self.dtype) - row[0] = timestamp + if not self.mmap: + raise OSError("Memory map not available") - for i, col in enumerate(self.columns): - if col in values: - row[i + 1] = values[col] + # Create row array + row: np.ndarray = np.zeros(len(self.columns) + 1, dtype=self.dtype) + row[0] = timestamp - # Calculate offset - offset = self.current_size * row.nbytes + for i, col in enumerate(self.columns): + if col in values: + row[i + 1] = values[col] - # Check if we need more space - if self.mmap and offset + row.nbytes > len(self.mmap): - new_size = max(offset + row.nbytes, len(self.mmap) * 2) - self._resize_file(new_size) - elif not self.mmap: - self.open() - if self.mmap and offset + row.nbytes > len(self.mmap): - new_size = max(offset + row.nbytes, len(self.mmap) * 2) + # Calculate offset + offset = self.current_size * self.row_size + + # Check if we need more space + if offset + self.row_size > self._data_file_size: + new_size = max(offset + self.row_size, self._data_file_size * 2) self._resize_file(new_size) - # Write row directly to mmap - if self.mmap: - self.mmap[offset : offset + row.nbytes] = row.tobytes() - self.mmap.flush() - self.current_size += 1 + # Write row directly to mmap + if self.mmap: + self.mmap[offset : offset + self.row_size] = row.tobytes() + self.mmap.flush() + self.current_size += 1 - return True + return True - except Exception as e: - print(f"Error appending data: {e}") - return False + except Exception: + logger.exception("Error appending data") + return False + + def _get_row(self, index: int) -> np.ndarray | None: + """Reads a single row by index.""" + if not self.mmap or index < 0 or index >= self.current_size: + return None + + offset = index * self.row_size + if offset + self.row_size > len(self.mmap): + return None + + row_bytes = self.mmap[offset : offset + self.row_size] + return np.frombuffer(row_bytes, dtype=self.dtype, count=len(self.columns) + 1) def read_window(self, start_time: float, end_time: float) -> pl.DataFrame | None: """ @@ -390,42 +453,55 @@ def read_window(self, start_time: float, end_time: float) -> pl.DataFrame | None Returns: DataFrame with data in the window """ - try: - # Read all data directly from mmap (don't use read_array which expects pickle) - if not self.mmap: - self.open() - - if not self.mmap: + with self._lock: + try: + if not self.mmap: + self.open() + + if not self.mmap or self.current_size == 0: + return None + + # Binary search to find the first row >= start_time + low, high = 0, self.current_size - 
1 + start_index = self.current_size + + while low <= high: + mid = (low + high) // 2 + row = self._get_row(mid) + if row is not None and row[0] >= start_time: + start_index = mid + high = mid - 1 + elif row is not None: + low = mid + 1 + else: # Should not happen in this loop + break + + if start_index >= self.current_size: + return None # No data in the window + + # Read data sequentially from start_index + all_data = [] + for i in range(start_index, self.current_size): + row = self._get_row(i) + if row is not None: + if row[0] > end_time: + break + all_data.append(row) + + if not all_data: + return None + + # Convert to DataFrame + data_array = np.vstack(all_data) + df_dict = {"timestamp": data_array[:, 0]} + + for i, col in enumerate(self.columns): + df_dict[col] = data_array[:, i + 1] + + return pl.DataFrame(df_dict) + + except Exception: + logger.exception( + "Error reading window from %f to %f", start_time, end_time + ) return None - - all_data = [] - row_size = (len(self.columns) + 1) * np.dtype(self.dtype).itemsize - - for i in range(self.current_size): - offset = i * row_size - - # Read raw bytes and convert to array - if self.mmap and offset + row_size <= len(self.mmap): - row_bytes = self.mmap[offset : offset + row_size] - row: np.ndarray = np.frombuffer( - row_bytes, dtype=self.dtype, count=len(self.columns) + 1 - ) - - if row is not None and start_time <= row[0] <= end_time: - all_data.append(row) - - if not all_data: - return None - - # Convert to DataFrame - data_array = np.vstack(all_data) - df_dict = {"timestamp": data_array[:, 0]} - - for i, col in enumerate(self.columns): - df_dict[col] = data_array[:, i + 1] - - return pl.DataFrame(df_dict) - - except Exception as e: - print(f"Error reading window: {e}") - return None diff --git a/src/project_x_py/event_bus.py b/src/project_x_py/event_bus.py index 24441a8..eb9fd90 100644 --- a/src/project_x_py/event_bus.py +++ b/src/project_x_py/event_bus.py @@ -85,7 +85,7 @@ def __init__(self, type: EventType | str, data: Any, source: str | None = None): self.type = type # keep raw string self.data = data self.source = source - self.timestamp = asyncio.get_event_loop().time() + self.timestamp = asyncio.get_running_loop().time() class EventBus: diff --git a/src/project_x_py/indicators/__init__.py b/src/project_x_py/indicators/__init__.py index 03744c7..cf6dd31 100644 --- a/src/project_x_py/indicators/__init__.py +++ b/src/project_x_py/indicators/__init__.py @@ -202,7 +202,7 @@ ) # Version info -__version__ = "3.1.12" +__version__ = "3.1.13" __author__ = "TexasCoding" diff --git a/src/project_x_py/order_manager/bracket_orders.py b/src/project_x_py/order_manager/bracket_orders.py index 921d8ba..db93069 100644 --- a/src/project_x_py/order_manager/bracket_orders.py +++ b/src/project_x_py/order_manager/bracket_orders.py @@ -215,55 +215,104 @@ async def place_bracket_order( ) if not entry_response or not entry_response.success: - raise ProjectXOrderError("Failed to place entry order") + raise ProjectXOrderError("Failed to place entry order for bracket.") - # Place stop loss (opposite side) - stop_side = 1 if side == 0 else 0 - stop_response = await self.place_stop_order( - contract_id, stop_side, size, stop_loss_price, account_id - ) - - # Place take profit (opposite side) - target_response = await self.place_limit_order( - contract_id, stop_side, size, take_profit_price, account_id + entry_order_id = entry_response.orderId + logger.info( + f"Bracket entry order {entry_order_id} placed. Waiting for fill..." 
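The core of the `read_window` speedup is a leftmost binary search over fixed-width rows; isolated here for clarity, with `get_row` standing in for `_get_row`:

```python
from typing import Callable, Optional
import numpy as np

def first_row_at_or_after(
    get_row: Callable[[int], Optional[np.ndarray]], n_rows: int, start_time: float
) -> int:
    """Smallest index whose timestamp (row[0]) is >= start_time, else n_rows."""
    low, high, result = 0, n_rows - 1, n_rows
    while low <= high:
        mid = (low + high) // 2
        row = get_row(mid)
        if row is not None and row[0] >= start_time:
            result = mid
            high = mid - 1
        else:
            low = mid + 1
    return result
```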
) - # Create bracket response - bracket_response = BracketOrderResponse( - success=True, - entry_order_id=entry_response.orderId, - stop_order_id=stop_response.orderId if stop_response else None, - target_order_id=target_response.orderId if target_response else None, - entry_price=entry_price if entry_price else 0.0, - stop_loss_price=stop_loss_price if stop_loss_price else 0.0, - take_profit_price=take_profit_price if take_profit_price else 0.0, - entry_response=entry_response, - stop_response=stop_response, - target_response=target_response, - error_message=None, + # Wait for the entry order to fill + is_filled = await self._wait_for_order_fill( + entry_order_id, timeout_seconds=60 ) - # Track bracket relationship - self.position_orders[contract_id]["entry_orders"].append( - entry_response.orderId - ) - if stop_response: - self.position_orders[contract_id]["stop_orders"].append( - stop_response.orderId + if not is_filled: + logger.warning( + f"Bracket entry order {entry_order_id} did not fill. Cancelling." ) - if target_response: - self.position_orders[contract_id]["target_orders"].append( - target_response.orderId + try: + await self.cancel_order(entry_order_id, account_id) + except ProjectXOrderError as e: + logger.error( + f"Failed to cancel unfilled bracket entry order {entry_order_id}: {e}" + ) + raise ProjectXOrderError( + f"Bracket entry order {entry_order_id} did not fill." ) - self.stats["bracket_orders"] += 1 logger.info( - f"โœ… Bracket order placed: Entry={entry_response.orderId}, " - f"Stop={stop_response.orderId if stop_response else 'None'}, " - f"Target={target_response.orderId if target_response else 'None'}" + f"Bracket entry order {entry_order_id} filled. Placing protective orders." ) - return bracket_response + stop_response = None + target_response = None + try: + # Place stop loss (opposite side) + stop_side = 1 if side == 0 else 0 + stop_response = await self.place_stop_order( + contract_id, stop_side, size, stop_loss_price, account_id + ) + + # Place take profit (opposite side) + target_response = await self.place_limit_order( + contract_id, stop_side, size, take_profit_price, account_id + ) + + if ( + not stop_response + or not stop_response.success + or not target_response + or not target_response.success + ): + raise ProjectXOrderError( + "Failed to place one or both protective orders." + ) + + # Link the two protective orders for OCO + stop_order_id = stop_response.orderId + target_order_id = target_response.orderId + self._link_oco_orders(stop_order_id, target_order_id) + + # Track all orders for the position + await self.track_order_for_position( + contract_id, entry_order_id, "entry" + ) + await self.track_order_for_position(contract_id, stop_order_id, "stop") + await self.track_order_for_position( + contract_id, target_order_id, "target" + ) + + self.stats["bracket_orders"] += 1 + logger.info( + f"โœ… Bracket order completed: Entry={entry_order_id}, Stop={stop_order_id}, Target={target_order_id}" + ) + + return BracketOrderResponse( + success=True, + entry_order_id=entry_order_id, + stop_order_id=stop_order_id, + target_order_id=target_order_id, + entry_price=entry_price, + stop_loss_price=stop_loss_price, + take_profit_price=take_profit_price, + entry_response=entry_response, + stop_response=stop_response, + target_response=target_response, + error_message=None, + ) + except Exception as e: + logger.error( + f"Failed to place protective orders for filled entry {entry_order_id}: {e}. Closing position." 
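Seen from the caller, the reworked flow still returns a single `BracketOrderResponse`; a usage sketch, where `om` is assumed to be an initialized OrderManager and the keyword names follow the patch:

```python
async def enter_long_bracket(om) -> None:
    result = await om.place_bracket_order(
        contract_id="CON.F.US.MNQ.U25",
        side=0,  # buy; protective orders go on the opposite side
        size=1,
        entry_price=17890.0,
        stop_loss_price=17860.0,
        take_profit_price=17950.0,
    )
    # If this returns, the entry filled and the stop/target are live and OCO-linked.
    print(result.entry_order_id, result.stop_order_id, result.target_order_id)
```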
+ ) + await self.close_position(contract_id, account_id=account_id) + if stop_response and stop_response.success: + await self.cancel_order(stop_response.orderId) + if target_response and target_response.success: + await self.cancel_order(target_response.orderId) + raise ProjectXOrderError( + f"Failed to place protective orders: {e}" + ) from e except Exception as e: logger.error(f"Failed to place bracket order: {e}") diff --git a/src/project_x_py/order_manager/tracking.py b/src/project_x_py/order_manager/tracking.py index 9c1bbc3..5ce5893 100644 --- a/src/project_x_py/order_manager/tracking.py +++ b/src/project_x_py/order_manager/tracking.py @@ -82,6 +82,10 @@ class OrderTrackingMixin: _realtime_enabled: bool event_bus: Any # EventBus instance + async def cancel_order( + self, order_id: int, account_id: int | None = None + ) -> bool: ... + def __init__(self) -> None: """Initialize tracking attributes.""" # Internal order state tracking (for realtime optimization) @@ -95,6 +99,14 @@ def __init__(self) -> None: lambda: {"stop_orders": [], "target_orders": [], "entry_orders": []} ) self.order_to_position: dict[int, str] = {} # order_id -> contract_id + self.oco_groups: dict[int, int] = {} # order_id -> other_order_id + + def _link_oco_orders( + self: "OrderManagerProtocol", order1_id: int, order2_id: int + ) -> None: + """Links two orders for OCO cancellation.""" + self.oco_groups[order1_id] = order2_id + self.oco_groups[order2_id] = order1_id async def _setup_realtime_callbacks(self) -> None: """Set up callbacks for real-time order monitoring.""" @@ -170,15 +182,45 @@ async def _on_order_update(self, order_data: dict[str, Any] | list[Any]) -> None } if new_status in status_events: - await self._trigger_callbacks( - status_events[new_status], - { - "order_id": order_id, - "order_data": actual_order_data, + from project_x_py.models import Order + + try: + order_obj = Order(**actual_order_data) + event_payload = { + "order": order_obj, + "order_id": order_id, # Add order_id for compatibility "old_status": old_status, "new_status": new_status, - }, - ) + } + await self._trigger_callbacks( + status_events[new_status], event_payload + ) + except Exception as e: + logger.error( + f"Failed to create Order object from data: {e}", + extra={"order_data": actual_order_data}, + ) + + # OCO Logic: If a linked order is filled, cancel the other. + if new_status == 2: # Filled + order_id_int = int(order_id) + if order_id_int in self.oco_groups: + other_order_id = self.oco_groups[order_id_int] + logger.info( + f"Order {order_id_int} filled, cancelling OCO sibling {other_order_id}." 
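The OCO bookkeeping reduces to a symmetric map; a standalone sketch of the mechanism:

```python
oco_groups: dict[int, int] = {}

def link_oco(order_a: int, order_b: int) -> None:
    # Each protective order points at its sibling.
    oco_groups[order_a] = order_b
    oco_groups[order_b] = order_a

def on_filled(order_id: int) -> int | None:
    """Return the sibling to cancel (and forget the pair), if any."""
    sibling = oco_groups.pop(order_id, None)
    if sibling is not None:
        oco_groups.pop(sibling, None)
    return sibling

link_oco(101, 102)
assert on_filled(101) == 102 and not oco_groups
```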
+ ) + try: + # Use create_task to avoid blocking the event handler + asyncio.create_task(self.cancel_order(other_order_id)) # noqa: RUF006 + except Exception as e: + logger.error( + f"Failed to cancel OCO order {other_order_id}: {e}" + ) + + # Clean up OCO group + del self.oco_groups[order_id_int] + if other_order_id in self.oco_groups: + del self.oco_groups[other_order_id] # Check for partial fills fills = actual_order_data.get("fills", []) @@ -345,3 +387,70 @@ def get_realtime_validation_status(self: "OrderManagerProtocol") -> dict[str, An "position_links": len(self.order_to_position), "monitored_positions": len(self.position_orders), } + + async def _wait_for_order_fill( + self: "OrderManagerProtocol", order_id: int, timeout_seconds: int = 30 + ) -> bool: + """Waits for an order to fill using an event-driven approach.""" + fill_event = asyncio.Event() + is_filled = False + + async def fill_handler(event: Any) -> None: + nonlocal is_filled + # Extract data from Event object + event_data = event.data if hasattr(event, "data") else event + if isinstance(event_data, dict): + # Check both direct order_id and order.id from Order object + event_order_id = event_data.get("order_id") + if not event_order_id and "order" in event_data: + order_obj = event_data.get("order") + if hasattr(order_obj, "id"): + event_order_id = order_obj.id + if event_order_id == order_id: + is_filled = True + fill_event.set() + + async def terminal_handler(event: Any) -> None: + nonlocal is_filled + # Extract data from Event object + event_data = event.data if hasattr(event, "data") else event + if isinstance(event_data, dict): + # Check both direct order_id and order.id from Order object + event_order_id = event_data.get("order_id") + if not event_order_id and "order" in event_data: + order_obj = event_data.get("order") + if hasattr(order_obj, "id"): + event_order_id = order_obj.id + if event_order_id == order_id: + is_filled = False + fill_event.set() + + from project_x_py.event_bus import EventType + + await self.event_bus.on(EventType.ORDER_FILLED, fill_handler) + await self.event_bus.on(EventType.ORDER_CANCELLED, terminal_handler) + await self.event_bus.on(EventType.ORDER_REJECTED, terminal_handler) + await self.event_bus.on(EventType.ORDER_EXPIRED, terminal_handler) + + try: + await asyncio.wait_for(fill_event.wait(), timeout=timeout_seconds) + except TimeoutError: + logger.warning(f"Timeout waiting for order {order_id} to fill/terminate.") + is_filled = False + finally: + # Clean up the event handlers + if hasattr(self.event_bus, "remove_callback"): + await self.event_bus.remove_callback( + EventType.ORDER_FILLED, fill_handler + ) + await self.event_bus.remove_callback( + EventType.ORDER_CANCELLED, terminal_handler + ) + await self.event_bus.remove_callback( + EventType.ORDER_REJECTED, terminal_handler + ) + await self.event_bus.remove_callback( + EventType.ORDER_EXPIRED, terminal_handler + ) + + return is_filled diff --git a/src/project_x_py/order_tracker.py b/src/project_x_py/order_tracker.py index a501fff..b692b0e 100644 --- a/src/project_x_py/order_tracker.py +++ b/src/project_x_py/order_tracker.py @@ -54,9 +54,12 @@ import asyncio import logging +import warnings from types import TracebackType from typing import TYPE_CHECKING, Any, Union +from typing_extensions import deprecated + from project_x_py.event_bus import EventType from project_x_py.models import BracketOrderResponse, Order, OrderPlaceResponse @@ -127,15 +130,17 @@ async def _setup_event_handlers(self) -> None: # Handler for order fills async def 
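Both handlers above repeat the same payload-compatibility dance; extracted as a helper for clarity (not part of the patch):

```python
from typing import Any

def extract_order_id(event_data: Any) -> int | None:
    """Accept both the flat {"order_id": ...} payload and the nested
    {"order": Order} shape emitted after this change."""
    if not isinstance(event_data, dict):
        return None
    order_id = event_data.get("order_id")
    if order_id is None and "order" in event_data:
        order_id = getattr(event_data["order"], "id", None)
    return order_id
```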
on_fill(data: dict[str, Any]) -> None: - if data.get("order_id") == self.order_id: - self._filled_order = data.get("order_data") + order = data.get("order") + if order and order.id == self.order_id: + self._filled_order = order self._current_status = 2 # FILLED self._fill_event.set() # Handler for status changes async def on_status_change(data: dict[str, Any]) -> None: - if data.get("order_id") == self.order_id: - new_status = data.get("new_status") + order = data.get("order") + if order and order.id == self.order_id: + new_status = order.status self._current_status = new_status # Set status-specific events @@ -252,32 +257,38 @@ async def wait_for_status(self, status: int, timeout: float = 30.0) -> Order: if status not in self._status_events: self._status_events[status] = asyncio.Event() - # Check if already at target status + # Check current status before waiting if self._current_status == status: order = await self.order_manager.get_order_by_id(self.order_id) - if order: + if order and order.status == status: return order + # Wait for the event try: await asyncio.wait_for(self._status_events[status].wait(), timeout=timeout) - - if self._error and status != self._current_status: - raise self._error - - # Fetch latest order data + except TimeoutError: + # After timeout, check the status one last time via API order = await self.order_manager.get_order_by_id(self.order_id) if order and order.status == status: return order - else: - raise OrderLifecycleError( - f"Status event received but order not in expected state {status}" - ) - - except TimeoutError: raise TimeoutError( f"Order {self.order_id} did not reach status {status} within {timeout} seconds" ) from None + # After event is received + if self._error and status != self._current_status: + raise self._error + + order = await self.order_manager.get_order_by_id(self.order_id) + if order and order.status == status: + return order + else: + # This can happen if event fires but API state is not yet consistent, + # or if another status update arrived quickly. + raise OrderLifecycleError( + f"Status event received but order not in expected state {status}. Current state: {order.status if order else 'not found'}" + ) + async def modify_or_cancel( self, new_price: float | None = None, new_size: int | None = None ) -> bool: @@ -429,10 +440,9 @@ def with_take_profit( self, offset: float | None = None, price: float | None = None, - limit: bool = True, ) -> "OrderChainBuilder": """Add a take profit to the order chain.""" - self.take_profit = {"offset": offset, "price": price, "limit": limit} + self.take_profit = {"offset": offset, "price": price} return self def with_trail_stop( @@ -519,9 +529,33 @@ async def execute(self) -> BracketOrderResponse: ) # Add trailing stop if configured - if self.trail_stop and result.success and result.entry_order_id: - # TODO: Implement trailing stop order - logger.warning("Trailing stop orders not yet implemented") + if self.trail_stop and result.success and result.stop_order_id: + logger.info( + f"Replacing stop order {result.stop_order_id} with trailing stop." 
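The restructured `wait_for_status` boils down to "wait for the event, but on timeout poll the REST state once more before failing"; the generic pattern, sketched:

```python
import asyncio
from typing import Awaitable, Callable

async def wait_then_recheck(
    event: asyncio.Event, recheck: Callable[[], Awaitable[bool]], timeout: float
) -> bool:
    """True if the event fired, or if a final API re-check succeeds after timeout."""
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        return True
    except TimeoutError:
        # The terminal event may have been missed; use the API as tie-breaker.
        return await recheck()
```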
+ ) + try: + await self.order_manager.cancel_order(result.stop_order_id) + trail_offset = self.trail_stop["offset"] + stop_side = 1 if self.side == 0 else 0 # Opposite of entry + + trail_response = await self.order_manager.place_trailing_stop_order( + contract_id=contract_id, + side=stop_side, + size=self.size, + trail_price=trail_offset, + ) + if trail_response.success: + logger.info( + f"Trailing stop order placed: {trail_response.orderId}" + ) + # Note: The BracketOrderResponse does not have a field for the trailing stop ID. + # The original stop_order_id will remain in the response. + else: + logger.error( + f"Failed to place trailing stop: {trail_response.errorMessage}" + ) + except Exception as e: + logger.error(f"Error replacing stop with trailing stop: {e}") return result @@ -571,6 +605,9 @@ class OrderLifecycleError(Exception): # Convenience function for creating order trackers +@deprecated( + "Use TradingSuite.track_order() instead. This function will be removed in v4.0.0." +) def track_order( trading_suite: "TradingSuite", order: Union[Order, OrderPlaceResponse, int] | None = None, @@ -593,6 +630,12 @@ def track_order( filled = await tracker.wait_for_fill() ``` """ + warnings.warn( + "track_order() is deprecated, use TradingSuite.track_order() instead. " + "This function will be removed in v4.0.0", + DeprecationWarning, + stacklevel=2, + ) tracker = OrderTracker(trading_suite) if order: if isinstance(order, Order | OrderPlaceResponse): diff --git a/src/project_x_py/position_manager/analytics.py b/src/project_x_py/position_manager/analytics.py index 27afa68..576e851 100644 --- a/src/project_x_py/position_manager/analytics.py +++ b/src/project_x_py/position_manager/analytics.py @@ -302,17 +302,6 @@ async def calculate_portfolio_pnl( total_value += value pnl_values.append(pnl) - # Calculate win/loss metrics directly from PnL values - winning_pnls = [pnl for pnl in pnl_values if pnl > 0] - losing_pnls = [pnl for pnl in pnl_values if pnl < 0] - - win_rate = len(winning_pnls) / len(positions) if positions else 0.0 - avg_win = sum(winning_pnls) / len(winning_pnls) if winning_pnls else 0.0 - avg_loss = sum(losing_pnls) / len(losing_pnls) if losing_pnls else 0.0 - largest_win = max(winning_pnls, default=0.0) - largest_loss = min(losing_pnls, default=0.0) - - profit_factor = abs(avg_win / avg_loss) if avg_loss < 0 else 0.0 total_return = (total_pnl / total_value * 100) if total_value > 0 else 0.0 from datetime import datetime @@ -320,7 +309,7 @@ async def calculate_portfolio_pnl( return { "total_value": total_value, "total_pnl": total_pnl, - "realized_pnl": 0.0, # Would need historical trade data + "realized_pnl": 0.0, # This is calculated and stored in stats now "unrealized_pnl": total_pnl, # All P&L is unrealized in this context "daily_pnl": 0.0, # Would need daily data "weekly_pnl": 0.0, # Would need weekly data @@ -328,18 +317,18 @@ async def calculate_portfolio_pnl( "ytd_pnl": 0.0, # Would need year-to-date data "total_return": total_return, "annualized_return": 0.0, # Would need time-weighted returns - "sharpe_ratio": 0.0, # Would need return volatility data - "sortino_ratio": 0.0, # Would need downside deviation data + "sharpe_ratio": 0.0, # Historical metric, should be in reporting + "sortino_ratio": 0.0, # Historical metric, should be in reporting "max_drawdown": 0.0, # Would need historical high-water marks - "win_rate": win_rate, - "profit_factor": profit_factor, - "avg_win": avg_win, - "avg_loss": avg_loss, + "win_rate": 0.0, # Historical metric, should be in reporting + 
"profit_factor": 0.0, # Historical metric, should be in reporting + "avg_win": 0.0, # Historical metric, should be in reporting + "avg_loss": 0.0, # Historical metric, should be in reporting "total_trades": len(positions), - "winning_trades": len(winning_pnls), - "losing_trades": len(losing_pnls), - "largest_win": largest_win, - "largest_loss": largest_loss, + "winning_trades": 0, # Historical metric + "losing_trades": 0, # Historical metric + "largest_win": 0.0, # Historical metric + "largest_loss": 0.0, # Historical metric "avg_trade_duration_minutes": 0.0, # Would need position entry times "last_updated": datetime.now().isoformat(), } diff --git a/src/project_x_py/position_manager/core.py b/src/project_x_py/position_manager/core.py index b8ac793..42fc86d 100644 --- a/src/project_x_py/position_manager/core.py +++ b/src/project_x_py/position_manager/core.py @@ -80,9 +80,16 @@ async def main(): from project_x_py.position_manager.monitoring import PositionMonitoringMixin from project_x_py.position_manager.operations import PositionOperationsMixin from project_x_py.position_manager.reporting import PositionReportingMixin -from project_x_py.position_manager.risk import RiskManagementMixin + +# from project_x_py.position_manager.risk import RiskManagementMixin # DEPRECATED from project_x_py.position_manager.tracking import PositionTrackingMixin +from project_x_py.risk_manager import RiskManager from project_x_py.types.config_types import PositionManagerConfig +from project_x_py.types.protocols import RealtimeDataManagerProtocol +from project_x_py.types.response_types import ( + PositionSizingResponse, + RiskAnalysisResponse, +) from project_x_py.utils import ( LogMessages, ProjectXLogger, @@ -98,7 +105,7 @@ async def main(): class PositionManager( PositionTrackingMixin, PositionAnalyticsMixin, - RiskManagementMixin, + # RiskManagementMixin, PositionMonitoringMixin, PositionOperationsMixin, PositionReportingMixin, @@ -170,6 +177,8 @@ def __init__( self, project_x_client: "ProjectXBase", event_bus: Any, + risk_manager: Optional["RiskManager"] = None, + data_manager: Optional["RealtimeDataManagerProtocol"] = None, config: PositionManagerConfig | None = None, ): """ @@ -183,6 +192,9 @@ def __init__( used for all API operations. Must be properly authenticated before use. event_bus: EventBus instance for unified event handling. Required for all event emissions including position updates, P&L changes, and risk alerts. + risk_manager: Optional risk manager instance. If provided, enables advanced + risk management features and position sizing calculations. + data_manager: Optional data manager for market data and P&L alerts. config: Optional configuration for position management behavior. If not provided, default values will be used for all configuration options. 
@@ -216,6 +228,8 @@ def __init__( self.project_x = project_x_client self.event_bus = event_bus # Store the event bus for emitting events + self.risk_manager = risk_manager + self.data_manager = data_manager self.logger = ProjectXLogger.get_logger(__name__) # Store configuration with defaults @@ -237,10 +251,14 @@ def __init__( self.stats = { "open_positions": 0, "closed_positions": 0, + "winning_positions": 0, + "losing_positions": 0, "total_positions": 0, "total_pnl": 0.0, "realized_pnl": 0.0, "unrealized_pnl": 0.0, + "gross_profit": 0.0, + "gross_loss": 0.0, "best_position_pnl": 0.0, "worst_position_pnl": 0.0, "avg_position_size": 0.0, @@ -565,6 +583,41 @@ async def is_position_open( position = await self.get_position(contract_id, account_id) return position is not None and position.size != 0 + # ================================================================================ + # RISK MANAGEMENT DELEGATION + # ================================================================================ + + async def get_risk_metrics(self) -> "RiskAnalysisResponse": + """Delegates risk metrics calculation to the main RiskManager.""" + if self.risk_manager: + return await self.risk_manager.get_risk_metrics() + else: + raise ValueError( + "Risk manager not configured. Enable 'risk_manager' feature in TradingSuite." + ) + + async def calculate_position_size( + self, + contract_id: str, + risk_amount: float, + entry_price: float, + stop_price: float, + account_balance: float | None = None, + ) -> "PositionSizingResponse": + """Delegates position sizing to the main RiskManager.""" + instrument = await self.project_x.get_instrument(contract_id) + if self.risk_manager: + return await self.risk_manager.calculate_position_size( + entry_price=entry_price, + stop_loss=stop_price, + risk_amount=risk_amount, + instrument=instrument, + ) + else: + raise ValueError( + "Risk manager not configured. Enable 'risk_manager' feature in TradingSuite." + ) + async def cleanup(self) -> None: """ Clean up resources and connections when shutting down. diff --git a/src/project_x_py/position_manager/monitoring.py b/src/project_x_py/position_manager/monitoring.py index 5d1afc3..d0b3fe4 100644 --- a/src/project_x_py/position_manager/monitoring.py +++ b/src/project_x_py/position_manager/monitoring.py @@ -67,16 +67,27 @@ class PositionMonitoringMixin: # Type hints for mypy - these attributes are provided by the main class if TYPE_CHECKING: + from project_x_py.client import ProjectXBase + from project_x_py.types.protocols import RealtimeDataManagerProtocol + position_lock: Lock logger: logging.Logger stats: dict[str, Any] _realtime_enabled: bool + project_x: "ProjectXBase" + data_manager: "RealtimeDataManagerProtocol | None" # Methods from other mixins/main class async def _trigger_callbacks( self, event_type: str, data: dict[str, Any] ) -> None: ... async def refresh_positions(self, account_id: int | None = None) -> bool: ... + async def calculate_position_pnl( + self, + position: Position, + current_price: float | None = None, + point_value: float | None = None, + ) -> Any: ... def __init__(self) -> None: """Initialize monitoring attributes.""" @@ -171,16 +182,50 @@ async def _check_position_alerts( separately. Currently only size change detection is implemented. 
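With the mixin retired, sizing now goes through the shared RiskManager; a usage sketch, where `pm` is a PositionManager constructed with `risk_manager=...` as above (the exact keys of `PositionSizingResponse` are not shown in this diff, so the result is only printed):

```python
async def size_trade(pm) -> None:
    sizing = await pm.calculate_position_size(
        contract_id="CON.F.US.MNQ.U25",
        risk_amount=250.0,
        entry_price=17890.0,
        stop_price=17860.0,
    )
    print(sizing)  # e.g. suggested size and risk percentage
```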
""" alert = self.position_alerts.get(contract_id) - if not alert or alert["triggered"]: + if not alert or alert.get("triggered"): return - # Note: P&L-based alerts require current market prices - # For now, only check position size changes alert_triggered = False alert_message = "" + # P&L-based alerts + if self.data_manager and ( + alert.get("max_loss") is not None or alert.get("max_gain") is not None + ): + try: + current_price = await self.data_manager.get_current_price() + if current_price is not None: + instrument = await self.project_x.get_instrument(contract_id) + point_value = getattr(instrument, "contractMultiplier", 1.0) + + pnl_data = await self.calculate_position_pnl( + current_position, current_price, point_value + ) + pnl = pnl_data["unrealized_pnl"] + + if alert.get("max_loss") is not None and pnl <= alert["max_loss"]: + alert_triggered = True + alert_message = f"Position {contract_id} breached max loss of {alert['max_loss']}. Current P&L: {pnl:.2f}" + + if ( + not alert_triggered + and alert.get("max_gain") is not None + and pnl >= alert["max_gain"] + ): + alert_triggered = True + alert_message = f"Position {contract_id} reached max gain of {alert['max_gain']}. Current P&L: {pnl:.2f}" + + except Exception as e: + self.logger.warning( + f"Could not check P&L alert for {contract_id} due to error: {e}" + ) + # Check for position size changes as a basic alert - if old_position and current_position.size != old_position.size: + if ( + not alert_triggered + and old_position + and current_position.size != old_position.size + ): size_change = current_position.size - old_position.size alert_triggered = True alert_message = ( diff --git a/src/project_x_py/position_manager/risk.py b/src/project_x_py/position_manager/risk.py index 41c95fd..8d91fb1 100644 --- a/src/project_x_py/position_manager/risk.py +++ b/src/project_x_py/position_manager/risk.py @@ -160,7 +160,7 @@ async def get_risk_metrics( ) # Generate risk warnings/recommendations - _risk_warnings = self._generate_risk_warnings( + _risk_warnings = self._generate_risk_warnings( # type: ignore[attr-defined] positions, portfolio_risk, largest_position_risk ) @@ -360,7 +360,7 @@ async def calculate_position_size( ) # Generate sizing warnings - sizing_warnings = self._generate_sizing_warnings( + sizing_warnings = self._generate_sizing_warnings( # type: ignore[attr-defined] risk_percentage, suggested_size ) diff --git a/src/project_x_py/position_manager/tracking.py b/src/project_x_py/position_manager/tracking.py index a5b20c9..99a8016 100644 --- a/src/project_x_py/position_manager/tracking.py +++ b/src/project_x_py/position_manager/tracking.py @@ -305,12 +305,49 @@ async def _process_position_data(self, position_data: dict[str, Any]) -> None: old_size: int = old_position.size if old_position else 0 if is_position_closed: - # Position is closed - remove from tracking and trigger closure callbacks + # Position is closed - calculate realized P&L and update stats + if old_position: + # Assume the averagePrice in the closing update is the exit price + exit_price = actual_position_data.get( + "averagePrice", old_position.averagePrice + ) + entry_price = old_position.averagePrice + size = old_position.size + + # This is a simplified P&L calculation. + # For futures, a point_value/multiplier is needed. + # Assuming point_value of 1 for now. 
+ if old_position.type == PositionType.LONG: + pnl = (exit_price - entry_price) * size + else: # SHORT + pnl = (entry_price - exit_price) * size + + self.stats["realized_pnl"] += pnl + self.stats["closed_positions"] += 1 + if pnl > 0: + self.stats["winning_positions"] = ( + self.stats.get("winning_positions", 0) + 1 + ) + self.stats["gross_profit"] = ( + self.stats.get("gross_profit", 0.0) + pnl + ) + if pnl > self.stats.get("best_position_pnl", 0.0): + self.stats["best_position_pnl"] = pnl + else: + self.stats["losing_positions"] = ( + self.stats.get("losing_positions", 0) + 1 + ) + self.stats["gross_loss"] = ( + self.stats.get("gross_loss", 0.0) + pnl + ) # pnl is negative + if pnl < self.stats.get("worst_position_pnl", 0.0): + self.stats["worst_position_pnl"] = pnl + + # Remove from tracking if contract_id in self.tracked_positions: del self.tracked_positions[contract_id] - self.logger.info(f"๐Ÿ“Š Position closed: {contract_id}") - self.stats["closed_positions"] = ( - self.stats.get("closed_positions", 0) + 1 + self.logger.info( + f"๐Ÿ“Š Position closed: {contract_id}, Realized P&L: {pnl:.2f}" ) # Synchronize orders - cancel related orders when position is closed @@ -321,26 +358,43 @@ async def _process_position_data(self, position_data: dict[str, Any]) -> None: await self._trigger_callbacks("position_closed", actual_position_data) else: # Position is open/updated - create or update position - # Build a complete Position object, filling defaults for missing fields - # Real-time updates may omit id/accountId/creationTimestamp - from datetime import UTC as _UTC, datetime as _dt - - position_dict: dict[str, Any] = { - "id": actual_position_data.get("id", -1), - "accountId": actual_position_data.get("accountId", -1), - "contractId": contract_id, - "creationTimestamp": actual_position_data.get( - "creationTimestamp", _dt.now(_UTC).isoformat() - ), - "type": actual_position_data.get("type", PositionType.UNDEFINED), - "size": position_size, - "averagePrice": actual_position_data.get("averagePrice", 0.0), - } + is_new_position = contract_id not in self.tracked_positions - position: Position = Position(**position_dict) + if is_new_position: + # For new positions, some fields might be missing from the real-time feed. + # We create a new object with defaults for any missing critical fields. + from datetime import UTC as _UTC, datetime as _dt + + position_dict: dict[str, Any] = { + "id": actual_position_data.get("id", -1), + "accountId": actual_position_data.get("accountId", -1), + "contractId": contract_id, + "creationTimestamp": actual_position_data.get( + "creationTimestamp", _dt.now(_UTC).isoformat() + ), + "type": actual_position_data.get( + "type", PositionType.UNDEFINED + ), + "size": position_size, + "averagePrice": actual_position_data.get("averagePrice", 0.0), + } + else: + # For existing positions, merge the update with the cached object + # to preserve fields like 'id' and 'creationTimestamp'. 
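The counters updated above are sufficient to rebuild the win/loss metrics that were dropped from `calculate_portfolio_pnl`; one possible derivation (not part of the patch):

```python
from typing import Any

def derive_trade_quality(stats: dict[str, Any]) -> dict[str, float]:
    closed = stats.get("closed_positions", 0)
    wins = stats.get("winning_positions", 0)
    gross_profit = stats.get("gross_profit", 0.0)
    gross_loss = stats.get("gross_loss", 0.0)  # accumulated as a negative number
    return {
        "win_rate": wins / closed if closed else 0.0,
        "profit_factor": gross_profit / abs(gross_loss) if gross_loss else 0.0,
    }
```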
+ existing_position = self.tracked_positions[contract_id] + # Manually construct dict from the existing position object + position_dict = { + "id": existing_position.id, + "accountId": existing_position.accountId, + "contractId": existing_position.contractId, + "creationTimestamp": existing_position.creationTimestamp, + "type": existing_position.type, + "size": existing_position.size, + "averagePrice": existing_position.averagePrice, + } + position_dict.update(actual_position_data) - # Check if this is a new position (didn't exist before) - is_new_position = contract_id not in self.tracked_positions + position: Position = Position(**position_dict) self.tracked_positions[contract_id] = position # Emit appropriate event diff --git a/src/project_x_py/realtime/batched_handler.py b/src/project_x_py/realtime/batched_handler.py index 5ae417e..0d5cb1f 100644 --- a/src/project_x_py/realtime/batched_handler.py +++ b/src/project_x_py/realtime/batched_handler.py @@ -62,6 +62,11 @@ def __init__( self.total_processing_time = 0.0 self.last_batch_time = time.time() + # Circuit breaker state + self.failed_batches = 0 + self.circuit_breaker_tripped_at: float | None = None + self.circuit_breaker_timeout = 60.0 # 60 seconds + # Lock for thread safety self._lock = asyncio.Lock() @@ -89,6 +94,18 @@ async def _process_batch(self) -> None: async with self._lock: if self.processing: return + + # Check circuit breaker + if self.circuit_breaker_tripped_at: + if ( + time.time() - self.circuit_breaker_tripped_at + ) > self.circuit_breaker_timeout: + logger.warning("Resetting circuit breaker.") + self.circuit_breaker_tripped_at = None + self.failed_batches = 0 + else: + return # Circuit breaker is tripped, do not process + self.processing = True try: @@ -136,6 +153,8 @@ async def _process_batch(self) -> None: if self.process_callback: try: await self.process_callback(batch) + # Reset failure count on success + self.failed_batches = 0 except asyncio.CancelledError: # Re-raise cancellation for proper shutdown raise @@ -145,12 +164,14 @@ async def _process_batch(self) -> None: exc_info=True, ) # Track failures for circuit breaker - self.failed_batches = getattr(self, "failed_batches", 0) + 1 + self.failed_batches += 1 if self.failed_batches > 10: logger.critical( - "Batch processing circuit breaker triggered" + f"Batch processing circuit breaker triggered for {self.circuit_breaker_timeout}s." 
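The cache-merge rule for partial real-time updates is simply "cached fields first, update overlays last"; in miniature:

```python
cached = {"id": 7, "accountId": 1, "contractId": "CON.F.US.MNQ.U25",
          "size": 2, "averagePrice": 17890.0}
update = {"size": 3}  # real-time payloads may omit id/accountId/creationTimestamp
merged = {**cached, **update}
assert merged["id"] == 7 and merged["size"] == 3  # identity kept, size refreshed
```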
) + self.circuit_breaker_tripped_at = time.time() self.processing = False + return # Stop processing # Update metrics processing_time = time.time() - start_time diff --git a/src/project_x_py/realtime/connection_management.py b/src/project_x_py/realtime/connection_management.py index e78680c..c30f7e3 100644 --- a/src/project_x_py/realtime/connection_management.py +++ b/src/project_x_py/realtime/connection_management.py @@ -269,7 +269,11 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool: await self.setup_connections() # Store the event loop for cross-thread task scheduling - self._loop = asyncio.get_event_loop() + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + logger.error("No running event loop found.") + return False logger.debug(LogMessages.WS_CONNECT) @@ -293,8 +297,22 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool: ) return False - # Wait for connections to establish and stabilize - await asyncio.sleep(2.0) # Increased wait time for connection stability + # Wait for connections to establish + try: + await asyncio.wait_for( + asyncio.gather( + self.user_hub_ready.wait(), self.market_hub_ready.wait() + ), + timeout=10.0, + ) + except TimeoutError: + logger.error( + LogMessages.WS_ERROR, + extra={ + "error": "Connection attempt timed out after 10 seconds." + }, + ) + return False if self.user_connected and self.market_connected: self.stats["connected_time"] = datetime.now() @@ -325,7 +343,7 @@ async def _start_connection_async( This is an internal method that bridges sync SignalR with async code. """ # SignalR connections are synchronous, so we run them in executor - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() await loop.run_in_executor(None, connection.start) logger.debug(LogMessages.WS_CONNECTED, extra={"hub": name}) @@ -365,13 +383,12 @@ async def disconnect(self: "ProjectXRealtimeClientProtocol") -> None: logger.debug(LogMessages.WS_DISCONNECT) async with self._connection_lock: + loop = asyncio.get_running_loop() if self.user_connection: - loop = asyncio.get_event_loop() await loop.run_in_executor(None, self.user_connection.stop) self.user_connected = False if self.market_connection: - loop = asyncio.get_event_loop() await loop.run_in_executor(None, self.market_connection.stop) self.market_connected = False @@ -390,6 +407,7 @@ def _on_user_hub_open(self: "ProjectXRealtimeClientProtocol") -> None: - Logs connection success """ self.user_connected = True + self.user_hub_ready.set() self.logger.info("โœ… User hub connected") def _on_user_hub_close(self: "ProjectXRealtimeClientProtocol") -> None: @@ -407,6 +425,7 @@ def _on_user_hub_close(self: "ProjectXRealtimeClientProtocol") -> None: Automatic reconnection will attempt based on configuration. """ self.user_connected = False + self.user_hub_ready.clear() self.logger.warning("โŒ User hub disconnected") def _on_market_hub_open(self: "ProjectXRealtimeClientProtocol") -> None: @@ -421,6 +440,7 @@ def _on_market_hub_open(self: "ProjectXRealtimeClientProtocol") -> None: - Logs connection success """ self.market_connected = True + self.market_hub_ready.set() self.logger.info("โœ… Market hub connected") def _on_market_hub_close(self: "ProjectXRealtimeClientProtocol") -> None: @@ -438,6 +458,7 @@ def _on_market_hub_close(self: "ProjectXRealtimeClientProtocol") -> None: Automatic reconnection will attempt based on configuration. 
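The breaker added to `BatchedWebSocketHandler` condenses to a small state machine; a minimal sketch with the same threshold and cool-down:

```python
import time

class SimpleCircuitBreaker:
    def __init__(self, threshold: int = 10, timeout: float = 60.0) -> None:
        self.threshold = threshold          # consecutive failures before tripping
        self.timeout = timeout              # seconds before auto-reset
        self.failures = 0
        self.tripped_at: float | None = None

    def allow(self) -> bool:
        if self.tripped_at is None:
            return True
        if time.time() - self.tripped_at > self.timeout:
            self.tripped_at, self.failures = None, 0  # reset after cool-down
            return True
        return False

    def record(self, success: bool) -> None:
        if success:
            self.failures = 0
        else:
            self.failures += 1
            if self.failures > self.threshold:
                self.tripped_at = time.time()
```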
""" self.market_connected = False + self.market_hub_ready.clear() self.logger.warning("โŒ Market hub disconnected") def _on_connection_error( diff --git a/src/project_x_py/realtime/core.py b/src/project_x_py/realtime/core.py index 3eb98a9..9a1e346 100644 --- a/src/project_x_py/realtime/core.py +++ b/src/project_x_py/realtime/core.py @@ -100,6 +100,7 @@ async def main(): from project_x_py.realtime.connection_management import ConnectionManagementMixin from project_x_py.realtime.event_handling import EventHandlingMixin from project_x_py.realtime.subscriptions import SubscriptionsMixin +from project_x_py.types.base import HubConnection if TYPE_CHECKING: from project_x_py.models import ProjectXConfig @@ -285,17 +286,16 @@ def __init__( self.base_market_url = "https://rtc.topstepx.com/hubs/market" # SignalR connection objects - self.user_connection: Any | None = None - self.market_connection: Any | None = None + self.user_connection: HubConnection | None = None + self.market_connection: HubConnection | None = None # Connection state tracking self.user_connected = False self.market_connected = False self.setup_complete = False - # Event callbacks (pure forwarding, no caching) - already initialized in mixin - if not hasattr(self, "callbacks"): - self.callbacks: defaultdict[str, list[Any]] = defaultdict(list) + # Event callbacks (pure forwarding, no caching) + self.callbacks: defaultdict[str, list[Any]] = defaultdict(list) # Basic statistics (no business logic) self.stats = { @@ -315,8 +315,10 @@ def __init__( self.logger.info(f"User Hub: {final_user_url}") self.logger.info(f"Market Hub: {final_market_url}") - # Async locks for thread-safe operations - check if not already initialized - if not hasattr(self, "_callback_lock"): - self._callback_lock = asyncio.Lock() - if not hasattr(self, "_connection_lock"): - self._connection_lock = asyncio.Lock() + # Async locks for thread-safe operations + self._callback_lock = asyncio.Lock() + self._connection_lock = asyncio.Lock() + + # Async events for connection readiness + self.user_hub_ready = asyncio.Event() + self.market_hub_ready = asyncio.Event() diff --git a/src/project_x_py/realtime/event_handling.py b/src/project_x_py/realtime/event_handling.py index fd33b27..d39b322 100644 --- a/src/project_x_py/realtime/event_handling.py +++ b/src/project_x_py/realtime/event_handling.py @@ -430,7 +430,7 @@ def _schedule_async_task( ) except Exception as e: # Fallback for logging - avoid recursion - print(f"Error scheduling async task: {e}") + self.logger.error(f"Error scheduling async task: {e}") else: # Fallback - try to create task in current loop context try: @@ -439,7 +439,7 @@ def _schedule_async_task( task.add_done_callback(lambda t: None) except RuntimeError: # No event loop available, log and continue - print(f"No event loop available for {event_type} event") + self.logger.error(f"No event loop available for {event_type} event") async def _forward_event_async( self: "ProjectXRealtimeClientProtocol", event_type: str, args: Any diff --git a/src/project_x_py/realtime/subscriptions.py b/src/project_x_py/realtime/subscriptions.py index 2b5d379..57307f6 100644 --- a/src/project_x_py/realtime/subscriptions.py +++ b/src/project_x_py/realtime/subscriptions.py @@ -138,6 +138,15 @@ async def subscribe_user_updates(self: "ProjectXRealtimeClientProtocol") -> bool ) return False + try: + await asyncio.wait_for(self.user_hub_ready.wait(), timeout=5.0) + except TimeoutError: + logger.error( + LogMessages.WS_ERROR, + extra={"error": "User hub not ready for subscriptions 
after 5s"}, + ) + return False + logger.debug( LogMessages.DATA_SUBSCRIBE, extra={"channel": "user_updates", "account_id": self.account_id}, @@ -149,30 +158,8 @@ async def subscribe_user_updates(self: "ProjectXRealtimeClientProtocol") -> bool ) return False - # Wait for transport to be ready - max_wait = 5.0 # Maximum 5 seconds - wait_interval = 0.1 - waited = 0.0 - - while waited < max_wait: - if ( - hasattr(self.user_connection, "transport") - and self.user_connection.transport - and hasattr(self.user_connection.transport, "is_running") - and self.user_connection.transport.is_running() - ): - break - await asyncio.sleep(wait_interval) - waited += wait_interval - else: - logger.error( - LogMessages.WS_ERROR, - extra={"error": "User hub transport not ready after waiting"}, - ) - return False - # ProjectX Gateway expects Subscribe method with account ID - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() # Subscribe to account updates await loop.run_in_executor( @@ -279,6 +266,15 @@ async def subscribe_market_data( ) return False + try: + await asyncio.wait_for(self.market_hub_ready.wait(), timeout=5.0) + except TimeoutError: + logger.error( + LogMessages.WS_ERROR, + extra={"error": "Market hub not ready for subscriptions after 5s"}, + ) + return False + logger.debug( LogMessages.DATA_SUBSCRIBE, extra={"channel": "market_data", "count": len(contract_ids)}, @@ -290,13 +286,9 @@ async def subscribe_market_data( self._subscribed_contracts.append(contract_id) # Subscribe using ProjectX Gateway methods (same as sync client) - loop = asyncio.get_event_loop() - - # Add a small delay to ensure the connection is fully established - await asyncio.sleep(0.5) + loop = asyncio.get_running_loop() for contract_id in contract_ids: - # Subscribe to quotes if self.market_connection is None: logger.error( LogMessages.WS_ERROR, @@ -304,28 +296,6 @@ async def subscribe_market_data( ) return False - # Wait for transport to be ready - max_wait = 5.0 # Maximum 5 seconds - wait_interval = 0.1 - waited = 0.0 - - while waited < max_wait: - if ( - hasattr(self.market_connection, "transport") - and self.market_connection.transport - and hasattr(self.market_connection.transport, "is_running") - and self.market_connection.transport.is_running() - ): - break - await asyncio.sleep(wait_interval) - waited += wait_interval - else: - logger.error( - LogMessages.WS_ERROR, - extra={"error": "Market hub transport not ready after waiting"}, - ) - return False - try: await loop.run_in_executor( None, @@ -333,29 +303,12 @@ async def subscribe_market_data( "SubscribeContractQuotes", [contract_id], ) - except Exception as e: - logger.error( - LogMessages.WS_ERROR, - extra={"error": f"Failed to subscribe to quotes: {e!s}"}, - ) - return False - # Subscribe to trades - try: await loop.run_in_executor( None, self.market_connection.send, "SubscribeContractTrades", [contract_id], ) - except Exception as e: - logger.error( - LogMessages.WS_ERROR, - extra={"error": f"Failed to subscribe to trades: {e!s}"}, - ) - return False - - # Subscribe to market depth - try: await loop.run_in_executor( None, self.market_connection.send, @@ -365,7 +318,7 @@ async def subscribe_market_data( except Exception as e: logger.error( LogMessages.WS_ERROR, - extra={"error": f"Failed to subscribe to market depth: {e!s}"}, + extra={"error": f"Failed to subscribe to {contract_id}: {e!s}"}, ) return False @@ -426,41 +379,27 @@ async def unsubscribe_user_updates(self: "ProjectXRealtimeClientProtocol") -> bo logger.debug( 
LogMessages.DATA_UNSUBSCRIBE, extra={"channel": "user_updates"} ) - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() + account_id_arg = [int(self.account_id)] # Unsubscribe from account updates await loop.run_in_executor( - None, - self.user_connection.send, - "UnsubscribeAccounts", - self.account_id, + None, self.user_connection.send, "UnsubscribeAccounts", account_id_arg ) # Unsubscribe from order updates await loop.run_in_executor( - None, - self.user_connection.send, - "UnsubscribeOrders", - [self.account_id], + None, self.user_connection.send, "UnsubscribeOrders", account_id_arg ) # Unsubscribe from position updates await loop.run_in_executor( - None, - self.user_connection.send, - "UnsubscribePositions", - self.account_id, + None, self.user_connection.send, "UnsubscribePositions", account_id_arg ) # Unsubscribe from trade updates await loop.run_in_executor( - None, - self.user_connection.send, - "UnsubscribeTrades", - self.account_id, + None, self.user_connection.send, "UnsubscribeTrades", account_id_arg ) logger.debug( @@ -532,7 +471,7 @@ async def unsubscribe_market_data( self._subscribed_contracts.remove(contract_id) # ProjectX Gateway expects Unsubscribe method - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if self.market_connection is None: logger.error( LogMessages.WS_ERROR, @@ -545,7 +484,7 @@ None, self.market_connection.send, "UnsubscribeContractQuotes", - [contract_ids], + contract_ids, ) # Unsubscribe from trades @@ -553,7 +492,7 @@ None, self.market_connection.send, "UnsubscribeContractTrades", - [contract_ids], + contract_ids, ) # Unsubscribe from market depth @@ -561,7 +500,7 @@ None, self.market_connection.send, "UnsubscribeContractMarketDepth", - [contract_ids], + contract_ids, ) logger.debug( diff --git a/src/project_x_py/realtime_data_manager/callbacks.py b/src/project_x_py/realtime_data_manager/callbacks.py index 2900b13..68d5c96 100644 --- a/src/project_x_py/realtime_data_manager/callbacks.py +++ b/src/project_x_py/realtime_data_manager/callbacks.py @@ -112,12 +112,23 @@ async def bar_callback(event): from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any +from project_x_py.event_bus import EventType + if TYPE_CHECKING: from project_x_py.types import RealtimeDataManagerProtocol logger = logging.getLogger(__name__) +_EVENT_TYPE_MAPPING = { + "new_bar": EventType.NEW_BAR, + "data_update": EventType.DATA_UPDATE, + "quote_update": EventType.QUOTE_UPDATE, + "trade_tick": EventType.TRADE_TICK, + "market_trade": EventType.TRADE_TICK, +} + + class CallbackMixin: """Mixin for event handling through EventBus.""" @@ -206,19 +217,8 @@ def on_data_update(data): - Exceptions in callbacks are caught and logged, preventing them from affecting the data manager's operation """ - # Map old event types to EventType enum - from project_x_py.event_bus import EventType - - event_mapping = { - "new_bar": EventType.NEW_BAR, - "data_update": EventType.DATA_UPDATE, - "quote_update": EventType.QUOTE_UPDATE, - "trade_tick": EventType.TRADE_TICK, - "market_trade": EventType.TRADE_TICK, - } - - if event_type in event_mapping: - await self.event_bus.on(event_mapping[event_type], callback) + if event_type in _EVENT_TYPE_MAPPING: + await self.event_bus.on(_EVENT_TYPE_MAPPING[event_type], callback) else: self.logger.warning(f"Unknown event type: {event_type}") @@ -232,20 +232,9 @@ async def _trigger_callbacks(
event_type: Type of event to trigger data: Data to pass to callbacks """ - from project_x_py.event_bus import EventType - - # Map event types to EventType enum - event_mapping = { - "new_bar": EventType.NEW_BAR, - "data_update": EventType.DATA_UPDATE, - "quote_update": EventType.QUOTE_UPDATE, - "trade_tick": EventType.TRADE_TICK, - "market_trade": EventType.TRADE_TICK, - } - - if event_type in event_mapping: + if event_type in _EVENT_TYPE_MAPPING: await self.event_bus.emit( - event_mapping[event_type], data, source="RealtimeDataManager" + _EVENT_TYPE_MAPPING[event_type], data, source="RealtimeDataManager" ) else: self.logger.warning(f"Unknown event type: {event_type}") diff --git a/src/project_x_py/realtime_data_manager/core.py b/src/project_x_py/realtime_data_manager/core.py index 7391d74..676623a 100644 --- a/src/project_x_py/realtime_data_manager/core.py +++ b/src/project_x_py/realtime_data_manager/core.py @@ -148,6 +148,16 @@ async def on_new_bar(event): from project_x_py.realtime import ProjectXRealtimeClient +class _DummyEventBus: + """A dummy event bus that does nothing, for use when no event bus is provided.""" + + async def on(self, event_type: Any, callback: Any) -> None: + """No-op event registration.""" + + async def emit(self, event_type: Any, data: Any, source: str | None = None) -> None: + """No-op event emission.""" + + class RealtimeDataManager( DataProcessingMixin, MemoryManagementMixin, @@ -331,7 +341,7 @@ def __init__( self.project_x: ProjectXBase = project_x self.realtime_client: ProjectXRealtimeClient = realtime_client # EventBus is optional in tests; fallback to a simple dummy if None - self.event_bus = event_bus if event_bus is not None else object() + self.event_bus = event_bus if event_bus is not None else _DummyEventBus() self.logger = ProjectXLogger.get_logger(__name__) @@ -728,20 +738,27 @@ async def stop_realtime_feed(self) -> None: self.is_running = False - # Cancel cleanup task + # Cancel background tasks first await self.stop_cleanup_task() - - # Cancel bar timer task await self._stop_bar_timer_task() - # Unsubscribe from market data - # Note: unsubscribe_market_data will be implemented in ProjectXRealtimeClient + # Unsubscribe from market data and remove callbacks if self.contract_id: self.logger.info(f"๐Ÿ“‰ Unsubscribing from {self.contract_id}") + # Unsubscribe from market data + await self.realtime_client.unsubscribe_market_data([self.contract_id]) + + # Remove callbacks + await self.realtime_client.remove_callback( + "quote_update", self._on_quote_update + ) + await self.realtime_client.remove_callback( + "market_trade", self._on_trade_update + ) self.logger.info(f"โœ… Real-time feed stopped for {self.instrument}") - except RuntimeError as e: + except Exception as e: self.logger.error(f"โŒ Error stopping real-time feed: {e}") async def cleanup(self) -> None: diff --git a/src/project_x_py/realtime_data_manager/data_access.py b/src/project_x_py/realtime_data_manager/data_access.py index 8c0ad2e..55dd0ff 100644 --- a/src/project_x_py/realtime_data_manager/data_access.py +++ b/src/project_x_py/realtime_data_manager/data_access.py @@ -107,7 +107,8 @@ import polars as pl if TYPE_CHECKING: - pass + from pytz import BaseTzInfo + logger = logging.getLogger(__name__) @@ -121,6 +122,7 @@ class DataAccessMixin: data: dict[str, pl.DataFrame] current_tick_data: list[dict[str, Any]] | deque[dict[str, Any]] tick_size: float + timezone: "BaseTzInfo" async def get_data( self, @@ -528,7 +530,7 @@ async def get_bars_since( if isinstance(timestamp, datetime) and 
timestamp.tzinfo is None: # Assume it's in the data's timezone # self.timezone is a pytz timezone object, we need its zone string - tz_str = "America/Chicago" + tz_str = str(self.timezone) timestamp = timestamp.replace(tzinfo=ZoneInfo(tz_str)) # Filter bars diff --git a/src/project_x_py/realtime_data_manager/data_processing.py b/src/project_x_py/realtime_data_manager/data_processing.py index 7c00954..6eb574c 100644 --- a/src/project_x_py/realtime_data_manager/data_processing.py +++ b/src/project_x_py/realtime_data_manager/data_processing.py @@ -125,10 +125,10 @@ class DataProcessingMixin: # Methods from other mixins/main class def _parse_and_validate_quote_payload( - self, data: dict[str, Any] + self, quote_data: Any ) -> dict[str, Any] | None: ... def _parse_and_validate_trade_payload( - self, data: dict[str, Any] + self, trade_data: Any ) -> dict[str, Any] | None: ... def _symbol_matches_instrument(self, symbol: str) -> bool: ... async def _trigger_callbacks( @@ -465,10 +465,6 @@ async def _update_timeframe_data( ] ) - # Prune memory - if self.data[tf_key].height > 1000: - self.data[tf_key] = self.data[tf_key].tail(1000) - # Return None if no new bar was created return None diff --git a/src/project_x_py/realtime_data_manager/memory_management.py b/src/project_x_py/realtime_data_manager/memory_management.py index 70f9700..6c1102b 100644 --- a/src/project_x_py/realtime_data_manager/memory_management.py +++ b/src/project_x_py/realtime_data_manager/memory_management.py @@ -166,7 +166,7 @@ async def _cleanup_old_data(self) -> None: # Keep only the most recent bars (sliding window) if initial_count > self.max_bars_per_timeframe: self.data[tf_key] = self.data[tf_key].tail( - self.max_bars_per_timeframe // 2 + self.max_bars_per_timeframe ) total_bars_after += len(self.data[tf_key]) diff --git a/src/project_x_py/risk_manager/core.py b/src/project_x_py/risk_manager/core.py index 65996eb..4e65708 100644 --- a/src/project_x_py/risk_manager/core.py +++ b/src/project_x_py/risk_manager/core.py @@ -20,6 +20,7 @@ OrderManagerProtocol, PositionManagerProtocol, ProjectXClientProtocol, + RealtimeDataManagerProtocol, ) from .config import RiskConfig @@ -42,24 +43,30 @@ def __init__( self, project_x: ProjectXClientProtocol, order_manager: OrderManagerProtocol, - position_manager: PositionManagerProtocol, event_bus: "EventBus", + position_manager: PositionManagerProtocol | None = None, config: RiskConfig | None = None, + data_manager: Optional["RealtimeDataManagerProtocol"] = None, ): """Initialize risk manager. 
Args: project_x: ProjectX client instance order_manager: Order manager instance - position_manager: Position manager instance event_bus: Event bus for risk events + position_manager: Optional position manager instance (can be set later) config: Risk configuration (uses defaults if not provided) + data_manager: Optional data manager for market data """ self.client = project_x self.orders = order_manager self.positions = position_manager + self.position_manager = ( + position_manager # Also store as position_manager for compatibility + ) self.event_bus = event_bus self.config = config or RiskConfig() + self.data_manager = data_manager # Track daily losses and trades self._daily_loss = Decimal("0") @@ -76,6 +83,11 @@ def __init__( self._current_risk = Decimal("0") self._max_drawdown = Decimal("0") + def set_position_manager(self, position_manager: PositionManagerProtocol) -> None: + """Set the position manager after initialization to resolve circular dependency.""" + self.positions = position_manager + self.position_manager = position_manager + async def calculate_position_size( self, entry_price: float, @@ -184,6 +196,9 @@ async def validate_trade( warnings = [] is_valid = True + if self.positions is None: + raise ValueError("Position manager not set") + # Get current positions if not provided if current_positions is None: current_positions = await self.positions.get_all_positions() @@ -305,15 +320,49 @@ async def attach_risk_orders( # Calculate stop loss if not provided if stop_loss is None and self.config.use_stop_loss: - if self.config.stop_loss_type == "fixed": - stop_distance = self.config.default_stop_distance * tick_size - elif self.config.stop_loss_type == "atr": - # TODO: Get ATR from data manager - stop_distance = self.config.default_stop_distance * tick_size - else: # percentage + if self.config.stop_loss_type == "atr": + if not self.data_manager: + logger.warning( + "ATR stop loss configured but no data manager is available. " + "Falling back to fixed stop." + ) + stop_distance = self.config.default_stop_distance * tick_size + else: + # Fetch data to calculate ATR. A common period for ATR is 14. + # We need enough data for the calculation. Let's fetch 50 bars. + # A default timeframe of '15min' is reasonable for ATR stops. + ohlc_data = await self.data_manager.get_data( + timeframe="15min", bars=50 + ) + if ohlc_data is None or ohlc_data.height < 14: + logger.warning( + "Not enough data to calculate ATR. Falling back to fixed stop." + ) + stop_distance = ( + self.config.default_stop_distance * tick_size + ) + else: + from project_x_py.indicators import calculate_atr + + data_with_atr = calculate_atr(ohlc_data, period=14) + latest_atr = data_with_atr["atr_14"].tail(1).item() + if latest_atr: + stop_distance = ( + latest_atr * self.config.default_stop_atr_multiplier + ) + else: + logger.warning( + "ATR calculation resulted in None. Falling back to fixed stop." 
+ ) + stop_distance = ( + self.config.default_stop_distance * tick_size + ) + elif self.config.stop_loss_type == "percentage": stop_distance = entry_price * ( self.config.default_stop_distance / 100 ) + else: # fixed + stop_distance = self.config.default_stop_distance * tick_size stop_loss = ( entry_price - stop_distance @@ -480,6 +529,9 @@ async def get_risk_metrics(self) -> RiskAnalysisResponse: Comprehensive risk analysis """ try: + if self.positions is None: + raise ValueError("Position manager not set") + account = await self._get_account_info() positions = await self.positions.get_all_positions() @@ -529,16 +581,8 @@ async def _get_account_info(self) -> "Account": accounts = await self.client.list_accounts() if accounts: return accounts[0] - # Create a default account if none found - from project_x_py.models import Account - - return Account( - id=0, - name="Default", - balance=10000.0, - canTrade=True, - isVisible=True, - simulated=True, + raise ValueError( + "No account found. RiskManager cannot proceed without account information." ) def _check_daily_reset(self) -> None: @@ -670,6 +714,31 @@ def _extract_symbol(self, contract_id: str) -> str: parts = contract_id.split(".") return parts[3] if len(parts) > 3 else contract_id + async def _get_market_price(self, contract_id: str) -> float: + """Get current market price for a contract.""" + if not self.data_manager: + raise RuntimeError("Data manager not available for market price fetching.") + + # This assumes the data_manager is configured for the correct instrument. + timeframes_to_try = ["1sec", "15sec", "1min", "5min"] + + for timeframe in timeframes_to_try: + try: + data = await self.data_manager.get_data(timeframe, bars=1) + if data is not None and not data.is_empty(): + return float(data["close"].tail(1).item()) + except Exception: + continue + + try: + current_price = await self.data_manager.get_current_price() + if current_price is not None: + return float(current_price) + except Exception: + pass + + raise RuntimeError(f"Unable to fetch current market price for {contract_id}") + async def _monitor_trailing_stop( self, position: "Position", @@ -682,6 +751,9 @@ async def _monitor_trailing_stop( while True: # Get current price + if self.positions is None: + raise ValueError("Position manager not set") + current_positions = await self.positions.get_all_positions() current_pos = next( (p for p in current_positions if p.id == position.id), None @@ -691,8 +763,16 @@ async def _monitor_trailing_stop( # Position closed break - # Check if trailing should activate - current_price = float(current_pos.averagePrice) + # Get current market price + try: + current_price = await self._get_market_price(position.contractId) + except RuntimeError as e: + logger.warning( + f"Could not fetch price for trailing stop on {position.contractId}: {e}" + ) + await asyncio.sleep(10) # Wait longer if price is unavailable + continue + profit = ( (current_price - entry_price) if is_long diff --git a/src/project_x_py/risk_manager/managed_trade.py b/src/project_x_py/risk_manager/managed_trade.py index eca794c..1a83ec1 100644 --- a/src/project_x_py/risk_manager/managed_trade.py +++ b/src/project_x_py/risk_manager/managed_trade.py @@ -1,8 +1,10 @@ """Managed trade context manager for risk-controlled trading.""" +import asyncio import logging from typing import TYPE_CHECKING, Any +from project_x_py.event_bus import EventType from project_x_py.types import OrderSide, OrderType from project_x_py.types.protocols import OrderManagerProtocol, PositionManagerProtocol 
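
The managed_trade.py hunks below replace the old polling loop with an event-driven wait: a handler sets an `asyncio.Event` when the matching order ID arrives, `asyncio.wait_for` bounds the wait, and handlers are always detached afterwards. A minimal, self-contained sketch of the same pattern follows; the `SimpleBus` class and the event names are illustrative stand-ins, not the SDK's EventBus API:

```python
import asyncio


class SimpleBus:
    """Illustrative stand-in for an event bus (not the SDK's EventBus)."""

    def __init__(self) -> None:
        self._handlers: dict[str, list] = {}

    async def on(self, event: str, handler) -> None:
        self._handlers.setdefault(event, []).append(handler)

    async def remove_callback(self, event: str, handler) -> None:
        self._handlers.get(event, []).remove(handler)

    async def emit(self, event: str, data: dict) -> None:
        for handler in self._handlers.get(event, []):
            await handler(data)


async def wait_for_fill(bus: SimpleBus, order_id: int, timeout: float = 10.0) -> bool:
    done = asyncio.Event()
    filled = False

    async def on_filled(data: dict) -> None:
        nonlocal filled
        if data.get("order_id") == order_id:
            filled = True
            done.set()

    async def on_terminal(data: dict) -> None:
        # Cancelled/rejected: stop waiting, but leave `filled` False
        if data.get("order_id") == order_id:
            done.set()

    await bus.on("order_filled", on_filled)
    await bus.on("order_cancelled", on_terminal)
    try:
        await asyncio.wait_for(done.wait(), timeout=timeout)
    except TimeoutError:
        pass  # treat a timeout as "not filled"
    finally:
        # Always detach handlers so they don't accumulate across trades
        await bus.remove_callback("order_filled", on_filled)
        await bus.remove_callback("order_cancelled", on_terminal)
    return filled


async def demo() -> None:
    bus = SimpleBus()
    waiter = asyncio.create_task(wait_for_fill(bus, order_id=42, timeout=2.0))
    await asyncio.sleep(0.1)
    await bus.emit("order_filled", {"order_id": 42})
    print(await waiter)  # True


asyncio.run(demo())
```

The key design point, mirrored in the hunks below, is the `finally` cleanup: without it, every bracket order would leave three dangling handlers on the bus.
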
@@ -32,6 +34,7 @@ def __init__( position_manager: PositionManagerProtocol, instrument_id: str, data_manager: Any | None = None, + event_bus: Any | None = None, max_risk_percent: float | None = None, max_risk_amount: float | None = None, ): @@ -43,6 +46,7 @@ def __init__( position_manager: Position manager instance instrument_id: Instrument/contract ID to trade data_manager: Optional data manager for market price fetching + event_bus: Optional event bus for event-driven waits max_risk_percent: Override max risk percentage max_risk_amount: Override max risk dollar amount """ @@ -51,6 +55,7 @@ def __init__( self.positions = position_manager self.instrument_id = instrument_id self.data_manager = data_manager + self.event_bus = event_bus self.max_risk_percent = max_risk_percent self.max_risk_amount = max_risk_amount @@ -618,17 +623,74 @@ async def _get_market_price(self) -> float: async def _wait_for_order_fill( self, order: "Order", timeout_seconds: int = 10 ) -> bool: - """Wait for an order to fill. + """Waits for an order to fill, using an event-driven approach if possible.""" + if not self.event_bus: + logger.warning( + "No event_bus available on ManagedTrade, falling back to polling for order fill." + ) + return await self._poll_for_order_fill(order, timeout_seconds) + + fill_event = asyncio.Event() + filled_successfully = False + + async def order_fill_handler(event: Any) -> None: + nonlocal filled_successfully + # Extract data from Event object + event_data = event.data if hasattr(event, "data") else event + if isinstance(event_data, dict): + # Check both direct order_id and order.id from Order object + event_order_id = event_data.get("order_id") + if not event_order_id and "order" in event_data: + order_obj = event_data.get("order") + if hasattr(order_obj, "id"): + event_order_id = order_obj.id + if event_order_id == order.id: + filled_successfully = True + fill_event.set() + + async def order_terminal_handler(event: Any) -> None: + nonlocal filled_successfully + # Extract data from Event object + event_data = event.data if hasattr(event, "data") else event + if isinstance(event_data, dict): + # Check both direct order_id and order.id from Order object + event_order_id = event_data.get("order_id") + if not event_order_id and "order" in event_data: + order_obj = event_data.get("order") + if hasattr(order_obj, "id"): + event_order_id = order_obj.id + if event_order_id == order.id: + filled_successfully = False + fill_event.set() + + await self.event_bus.on(EventType.ORDER_FILLED, order_fill_handler) + await self.event_bus.on(EventType.ORDER_CANCELLED, order_terminal_handler) + await self.event_bus.on(EventType.ORDER_REJECTED, order_terminal_handler) - Args: - order: Order to wait for - timeout_seconds: Maximum time to wait + try: + await asyncio.wait_for(fill_event.wait(), timeout=timeout_seconds) + except TimeoutError: + logger.warning(f"Timeout waiting for order {order.id} to fill via event.") + filled_successfully = False + finally: + # Important: Clean up the event handlers to prevent memory leaks + if hasattr(self.event_bus, "remove_callback"): + await self.event_bus.remove_callback( + EventType.ORDER_FILLED, order_fill_handler + ) + await self.event_bus.remove_callback( + EventType.ORDER_CANCELLED, order_terminal_handler + ) + await self.event_bus.remove_callback( + EventType.ORDER_REJECTED, order_terminal_handler + ) - Returns: - True if order filled, False if timeout - """ - import asyncio + return filled_successfully + async def _poll_for_order_fill( + self, order: "Order", 
timeout_seconds: int = 10 + ) -> bool: + """Wait for an order to fill by polling its status.""" start_time = asyncio.get_event_loop().time() check_interval = 0.5 # Check every 500ms diff --git a/src/project_x_py/trading_suite.py b/src/project_x_py/trading_suite.py index 9285837..446f49d 100644 --- a/src/project_x_py/trading_suite.py +++ b/src/project_x_py/trading_suite.py @@ -38,7 +38,7 @@ from enum import Enum from pathlib import Path from types import TracebackType -from typing import Any +from typing import Any, cast import orjson import yaml @@ -60,6 +60,7 @@ OrderManagerConfig, PositionManagerConfig, ) +from project_x_py.types.protocols import ProjectXClientProtocol from project_x_py.types.stats_types import ComponentStats, TradingSuiteStats from project_x_py.utils import ProjectXLogger @@ -87,6 +88,11 @@ def __init__( initial_days: int = 5, auto_connect: bool = True, timezone: str = "America/Chicago", + order_manager_config: OrderManagerConfig | None = None, + position_manager_config: PositionManagerConfig | None = None, + data_manager_config: DataManagerConfig | None = None, + orderbook_config: OrderbookConfig | None = None, + risk_config: RiskConfig | None = None, ): self.instrument = instrument self.timeframes = timeframes or ["5min"] @@ -94,9 +100,21 @@ def __init__( self.initial_days = initial_days self.auto_connect = auto_connect self.timezone = timezone + self.order_manager_config = order_manager_config + self.position_manager_config = position_manager_config + self.data_manager_config = data_manager_config + self.orderbook_config = orderbook_config + self.risk_config = risk_config def get_order_manager_config(self) -> OrderManagerConfig: - """Get configuration for OrderManager.""" + """ + Get configuration for OrderManager. + + Returns: + OrderManagerConfig: The configuration for the OrderManager. + """ + if self.order_manager_config: + return self.order_manager_config return { "enable_bracket_orders": Features.RISK_MANAGER in self.features, "enable_trailing_stops": True, @@ -105,7 +123,14 @@ def get_order_manager_config(self) -> OrderManagerConfig: } def get_position_manager_config(self) -> PositionManagerConfig: - """Get configuration for PositionManager.""" + """ + Get configuration for PositionManager. + + Returns: + PositionManagerConfig: The configuration for the PositionManager. + """ + if self.position_manager_config: + return self.position_manager_config return { "enable_risk_monitoring": Features.RISK_MANAGER in self.features, "enable_correlation_analysis": Features.PERFORMANCE_ANALYTICS @@ -114,7 +139,14 @@ def get_position_manager_config(self) -> PositionManagerConfig: } def get_data_manager_config(self) -> DataManagerConfig: - """Get configuration for RealtimeDataManager.""" + """ + Get configuration for RealtimeDataManager. + + Returns: + DataManagerConfig: The configuration for the RealtimeDataManager. + """ + if self.data_manager_config: + return self.data_manager_config return { "max_bars_per_timeframe": 1000, "enable_tick_data": True, @@ -124,7 +156,14 @@ def get_data_manager_config(self) -> DataManagerConfig: } def get_orderbook_config(self) -> OrderbookConfig: - """Get configuration for OrderBook.""" + """ + Get configuration for OrderBook. + + Returns: + OrderbookConfig: The configuration for the OrderBook. 
+ """ + if self.orderbook_config: + return self.orderbook_config return { "max_depth_levels": 100, "max_trade_history": 1000, @@ -133,7 +172,14 @@ def get_orderbook_config(self) -> OrderbookConfig: } def get_risk_config(self) -> RiskConfig: - """Get configuration for RiskManager.""" + """ + Get configuration for RiskManager. + + Returns: + RiskConfig: The configuration for the RiskManager. + """ + if self.risk_config: + return self.risk_config return RiskConfig( max_risk_per_trade=0.01, # 1% per trade max_daily_loss=0.03, # 3% daily loss @@ -201,9 +247,6 @@ def __init__( self.orders = OrderManager( client, config=config.get_order_manager_config(), event_bus=self.events ) - self.positions = PositionManager( - client, config=config.get_position_manager_config(), event_bus=self.events - ) # Optional components self.orderbook: OrderBook | None = None @@ -211,15 +254,25 @@ def __init__( self.journal = None # TODO: Future enhancement self.analytics = None # TODO: Future enhancement - # Initialize risk manager if enabled + # Create PositionManager first + self.positions = PositionManager( + client, + event_bus=self.events, + risk_manager=None, # Will be set later + data_manager=self.data, + config=config.get_position_manager_config(), + ) + + # Initialize risk manager if enabled and inject dependencies if Features.RISK_MANAGER in config.features: self.risk_manager = RiskManager( - project_x=client, + project_x=cast(ProjectXClientProtocol, client), order_manager=self.orders, - position_manager=self.positions, event_bus=self.events, + position_manager=self.positions, config=config.get_risk_config(), ) + self.positions.risk_manager = self.risk_manager # State tracking self._connected = False @@ -415,6 +468,9 @@ async def _initialize(self) -> None: await self.realtime.connect() await self.realtime.subscribe_user_updates() + # Initialize order manager with realtime client for order tracking + await self.orders.initialize(realtime_client=self.realtime) + # Initialize position manager with order manager for cleanup await self.positions.initialize( realtime_client=self.realtime, @@ -776,33 +832,42 @@ def get_stats(self) -> TradingSuiteStats: # Build component stats components: dict[str, ComponentStats] = {} if self.orders: + last_activity_obj = self.orders.stats.get("last_order_time") components["order_manager"] = ComponentStats( name="OrderManager", status="connected" if self.orders else "disconnected", uptime_seconds=uptime_seconds, - last_activity=None, - error_count=0, - memory_usage_mb=0.0, + last_activity=last_activity_obj.isoformat() + if last_activity_obj + else None, + error_count=0, # TODO: Implement error tracking in OrderManager + memory_usage_mb=0.0, # TODO: Implement memory tracking in OrderManager ) if self.positions: + last_activity_obj = self.positions.stats.get("last_position_update") components["position_manager"] = ComponentStats( name="PositionManager", status="connected" if self.positions else "disconnected", uptime_seconds=uptime_seconds, - last_activity=None, - error_count=0, - memory_usage_mb=0.0, + last_activity=last_activity_obj.isoformat() + if last_activity_obj + else None, + error_count=0, # TODO: Implement error tracking in PositionManager + memory_usage_mb=0.0, # TODO: Implement memory tracking in PositionManager ) if self.data: + last_activity_obj = self.data.memory_stats.get("last_update") components["data_manager"] = ComponentStats( name="RealtimeDataManager", status="connected" if self.data else "disconnected", uptime_seconds=uptime_seconds, - last_activity=None, - 
error_count=0, - memory_usage_mb=0.0, + last_activity=last_activity_obj.isoformat() + if last_activity_obj + else None, + error_count=self.data.memory_stats.get("data_validation_errors", 0), + memory_usage_mb=self.data.memory_stats.get("memory_usage_mb", 0.0), ) if self.orderbook: @@ -810,9 +875,11 @@ def get_stats(self) -> TradingSuiteStats: name="OrderBook", status="connected" if self.orderbook else "disconnected", uptime_seconds=uptime_seconds, - last_activity=None, - error_count=0, - memory_usage_mb=0.0, + last_activity=self.orderbook.last_orderbook_update.isoformat() + if self.orderbook.last_orderbook_update + else None, + error_count=0, # TODO: Implement error tracking in OrderBook + memory_usage_mb=0.0, # TODO: Implement memory tracking in OrderBook ) if self.risk_manager: @@ -820,9 +887,9 @@ def get_stats(self) -> TradingSuiteStats: name="RiskManager", status="active" if self.risk_manager else "inactive", uptime_seconds=uptime_seconds, - last_activity=None, - error_count=0, - memory_usage_mb=0.0, + last_activity=None, # TODO: Implement activity tracking in RiskManager + error_count=0, # TODO: Implement error tracking in RiskManager + memory_usage_mb=0.0, # TODO: Implement memory tracking in RiskManager ) return { diff --git a/src/project_x_py/types/base.py b/src/project_x_py/types/base.py index 3cb7706..97a3a63 100644 --- a/src/project_x_py/types/base.py +++ b/src/project_x_py/types/base.py @@ -78,7 +78,10 @@ def process_order(order_id: OrderId, contract_id: ContractId) -> None: """ from collections.abc import Callable, Coroutine -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from signalrcore.hub.base_hub_connection import BaseHubConnection # Type aliases for callbacks AsyncCallback = Callable[[dict[str, Any]], Coroutine[Any, Any, None]] @@ -95,9 +98,16 @@ def process_order(order_id: OrderId, contract_id: ContractId) -> None: OrderId = str PositionId = str +# SignalR connection type +if TYPE_CHECKING: + HubConnection = BaseHubConnection +else: + HubConnection = Any + __all__ = [ "DEFAULT_TIMEZONE", "TICK_SIZE_PRECISION", + "HubConnection", "AccountId", "AsyncCallback", "CallbackType", diff --git a/src/project_x_py/types/protocols.py b/src/project_x_py/types/protocols.py index d480579..588858d 100644 --- a/src/project_x_py/types/protocols.py +++ b/src/project_x_py/types/protocols.py @@ -95,6 +95,7 @@ async def place_order(self, contract_id: str, side: int) -> OrderPlaceResponse: import httpx import polars as pl +from project_x_py.types.base import HubConnection from project_x_py.types.response_types import ( PerformanceStatsResponse, PortfolioMetricsResponse, @@ -228,6 +229,7 @@ class OrderManagerProtocol(Protocol): order_status_cache: dict[str, int] position_orders: dict[str, dict[str, list[int]]] order_to_position: dict[int, str] + oco_groups: dict[int, int] # order_id -> other_order_id for OCO pairs # Methods that mixins need async def place_order( @@ -335,6 +337,19 @@ async def on_position_closed( async def _setup_realtime_callbacks(self) -> None: ... + # Methods used by bracket orders + async def _wait_for_order_fill( + self, order_id: int, timeout_seconds: int = 30 + ) -> bool: ... + def _link_oco_orders(self, order1_id: int, order2_id: int) -> None: ... + async def close_position( + self, + contract_id: str, + method: str = "market", + limit_price: float | None = None, + account_id: int | None = None, + ) -> "OrderPlaceResponse | None": ... 
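
The protocol stubs above document the bracket-order plumbing: `oco_groups` pairs each order with its sibling so that a fill on one side cancels the other. A minimal sketch of that bookkeeping, under assumed semantics (the real OrderManager's cancellation logic lives elsewhere):

```python
# Hypothetical illustration of the oco_groups bookkeeping: each order ID maps
# to its sibling, so a fill on one side looks up and cancels the other.
oco_groups: dict[int, int] = {}


def link_oco_orders(order1_id: int, order2_id: int) -> None:
    """Register two orders as an OCO (one-cancels-other) pair."""
    oco_groups[order1_id] = order2_id
    oco_groups[order2_id] = order1_id


def on_order_filled(order_id: int, cancel) -> None:
    """On a fill, cancel the paired order (if any) and drop both entries."""
    sibling = oco_groups.pop(order_id, None)
    if sibling is not None:
        oco_groups.pop(sibling, None)
        cancel(sibling)


# Usage: stop 101 and target 102 protect the same position.
link_oco_orders(101, 102)
on_order_filled(102, cancel=lambda oid: print(f"cancel {oid}"))  # prints "cancel 101"
```
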
+ class PositionManagerProtocol(Protocol): """Protocol defining the interface that mixins expect from PositionManager.""" @@ -380,15 +395,6 @@ async def get_all_positions( async def get_position( self, contract_id: str, account_id: int | None = None ) -> "Position | None": ... - def _generate_risk_warnings( - self, - positions: list["Position"], - portfolio_risk: float, - largest_position_risk: float, - ) -> list[str]: ... - def _generate_sizing_warnings( - self, risk_percentage: float, size: int - ) -> list[str]: ... async def refresh_positions(self, account_id: int | None = None) -> int: ... async def close_position_direct( self, contract_id: str, account_id: int | None = None @@ -406,9 +412,7 @@ async def get_portfolio_pnl( self, account_id: int | None = None, ) -> "PortfolioMetricsResponse": ... - async def get_risk_metrics( - self, account_id: int | None = None - ) -> "RiskAnalysisResponse": ... + async def get_risk_metrics(self) -> "RiskAnalysisResponse": ... def get_position_statistics( self, ) -> "PositionManagerStats": ... @@ -432,7 +436,7 @@ class RealtimeDataManagerProtocol(Protocol): # Data storage data: dict[str, pl.DataFrame] - current_tick_data: list[dict[str, Any]] + current_tick_data: Any # Can be list or deque last_bar_times: dict[str, datetime.datetime] # Synchronization @@ -464,7 +468,7 @@ async def _on_trade_update(self, callback_data: dict[str, Any]) -> None: ... async def _process_tick_data(self, tick: dict[str, Any]) -> None: ... async def _update_timeframe_data( self, tf_key: str, timestamp: datetime.datetime, price: float, volume: int - ) -> None: ... + ) -> dict[str, Any] | None: ... def _calculate_bar_time( self, timestamp: datetime.datetime, interval: int, unit: int ) -> datetime.datetime: ... @@ -490,7 +494,7 @@ async def add_callback( event_type: str, callback: Callable[[dict[str, Any]], Coroutine[Any, Any, None] | None], ) -> None: ... - def get_memory_stats(self) -> dict[str, Any]: ... + def get_memory_stats(self) -> Any: ... # Returns RealtimeDataManagerStats def get_realtime_validation_status(self) -> dict[str, Any]: ... async def cleanup(self) -> None: ... 
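
The next hunk adds typed hub connections and the two readiness events to `ProjectXRealtimeClientProtocol`. As a reminder of why these attribute declarations matter, here is a minimal, generic illustration (all names are illustrative, not SDK API) of how `typing.Protocol` lets a mixin type-check against whatever class ultimately supplies the attributes, and how the readiness events support a bounded wait instead of a fixed sleep:

```python
import asyncio
from typing import Protocol


class ConnectionLike(Protocol):
    """Structural interface a mixin expects from its host class."""

    user_hub_ready: asyncio.Event
    market_hub_ready: asyncio.Event


async def wait_until_ready(client: ConnectionLike, timeout: float = 10.0) -> bool:
    """Wait for both hubs; return False instead of hanging forever."""
    try:
        await asyncio.wait_for(
            asyncio.gather(
                client.user_hub_ready.wait(), client.market_hub_ready.wait()
            ),
            timeout=timeout,
        )
        return True
    except TimeoutError:
        return False


class FakeClient:
    """Satisfies ConnectionLike structurally -- no inheritance needed."""

    def __init__(self) -> None:
        self.user_hub_ready = asyncio.Event()
        self.market_hub_ready = asyncio.Event()


async def demo() -> None:
    client = FakeClient()
    client.user_hub_ready.set()
    client.market_hub_ready.set()
    print(await wait_until_ready(client, timeout=1.0))  # True


asyncio.run(demo())
```
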
@@ -511,8 +515,8 @@ class ProjectXRealtimeClientProtocol(Protocol): base_market_url: str # Connection objects - user_connection: Any | None - market_connection: Any | None + user_connection: HubConnection | None + market_connection: HubConnection | None # Connection state user_connected: bool @@ -529,9 +533,11 @@ class ProjectXRealtimeClientProtocol(Protocol): # Logging logger: Any - # Async locks + # Async locks and events _callback_lock: asyncio.Lock _connection_lock: asyncio.Lock + user_hub_ready: asyncio.Event + market_hub_ready: asyncio.Event # Event loop _loop: asyncio.AbstractEventLoop | None diff --git a/src/project_x_py/types/response_types.py b/src/project_x_py/types/response_types.py index df07e31..dc71999 100644 --- a/src/project_x_py/types/response_types.py +++ b/src/project_x_py/types/response_types.py @@ -93,13 +93,8 @@ class PerformanceStatsResponse(TypedDict): cache_hits: int cache_misses: int cache_hit_ratio: float - avg_response_time_ms: float total_requests: int - failed_requests: int - success_rate: float active_connections: int - memory_usage_mb: float - uptime_seconds: int # Risk and Analytics Response Types diff --git a/tests/client/test_cache.py b/tests/client/test_cache.py index 8bb41fc..9058c95 100644 --- a/tests/client/test_cache.py +++ b/tests/client/test_cache.py @@ -117,14 +117,12 @@ async def test_cache_cleanup(self, mock_project_x, mock_instrument): # Wait for expiry time.sleep(0.2) - # Force cleanup - await client._cleanup_cache() - - # Cache should be empty - assert len(client._opt_instrument_cache) == 0 - assert len(client._opt_instrument_cache_time) == 0 - assert len(client._opt_market_data_cache) == 0 - assert len(client._opt_market_data_cache_time) == 0 + # TTLCache automatically expires items - no manual cleanup needed + # Check that items have expired + assert client.get_cached_instrument("MGC") is None + assert client.get_cached_instrument("MNQ") is None + assert client.get_cached_market_data("key1") is None + assert client.get_cached_market_data("key2") is None @pytest.mark.asyncio async def test_clear_all_caches(self, mock_project_x, mock_instrument): @@ -145,9 +143,7 @@ async def test_clear_all_caches(self, mock_project_x, mock_instrument): # Cache should be empty assert len(client._opt_instrument_cache) == 0 - assert len(client._opt_instrument_cache_time) == 0 assert len(client._opt_market_data_cache) == 0 - assert len(client._opt_market_data_cache_time) == 0 @pytest.mark.asyncio async def test_cache_hit_tracking(self, mock_project_x, mock_instrument): diff --git a/tests/client/test_cache_optimized.py b/tests/client/test_cache_optimized.py index 59eeaf3..943fc2f 100644 --- a/tests/client/test_cache_optimized.py +++ b/tests/client/test_cache_optimized.py @@ -21,7 +21,7 @@ async def mock_project_x(self, mock_httpx_client): @pytest.mark.asyncio async def test_msgpack_serialization(self, mock_project_x): - """Test that DataFrames are serialized with msgpack.""" + """Test that DataFrames are serialized with Arrow IPC format.""" client = mock_project_x # Create test data @@ -38,7 +38,6 @@ async def test_msgpack_serialization(self, mock_project_x): # Check that data is stored in optimized cache (not compatibility cache) assert "test_key" in client._opt_market_data_cache - assert "test_key" in client._opt_market_data_cache_time # Retrieve and verify data integrity cached_data = client.get_cached_market_data("test_key") @@ -140,7 +139,7 @@ async def test_cache_statistics(self, mock_project_x, mock_instrument): # Verify values assert 
stats["compression_enabled"] is True - assert stats["serialization"] == "msgpack" + assert stats["serialization"] == "arrow-ipc" assert stats["compression"] == "lz4" assert stats["instrument_cache_size"] == 1 assert stats["market_data_cache_size"] == 1 diff --git a/tests/client/test_market_data.py b/tests/client/test_market_data.py index 1af8c4e..7a0c28a 100644 --- a/tests/client/test_market_data.py +++ b/tests/client/test_market_data.py @@ -540,12 +540,12 @@ async def test_get_bars_with_start_and_end_time( assert "timestamp" in bars.columns assert "open" in bars.columns - # Should use time-based cache key - start_utc = pytz.UTC.localize(start) - end_utc = pytz.UTC.localize(end) - cache_key = ( - f"MGC_{start_utc.isoformat()}_{end_utc.isoformat()}_15_2_True" - ) + # Should use time-based cache key with market timezone + # Client uses America/Chicago by default + market_tz = pytz.timezone("America/Chicago") + start_tz = market_tz.localize(start) + end_tz = market_tz.localize(end) + cache_key = f"MGC_{start_tz.isoformat()}_{end_tz.isoformat()}_15_2_True" assert cache_key in client._opt_market_data_cache @pytest.mark.asyncio @@ -677,12 +677,8 @@ async def test_get_bars_with_timezone_aware_times( assert not bars.is_empty() assert "timestamp" in bars.columns - # Cache key should use UTC times - start_utc = start.astimezone(pytz.UTC) - end_utc = end.astimezone(pytz.UTC) - cache_key = ( - f"MGC_{start_utc.isoformat()}_{end_utc.isoformat()}_30_2_True" - ) + # Cache key should use the same timezone as provided (Chicago) + cache_key = f"MGC_{start.isoformat()}_{end.isoformat()}_30_2_True" assert cache_key in client._opt_market_data_cache @pytest.mark.asyncio @@ -729,10 +725,12 @@ async def test_get_bars_time_params_override_days( ) # Verify that the cache key uses the time range, not days - start_utc = pytz.UTC.localize(start) - end_utc = pytz.UTC.localize(end) + # Client uses America/Chicago by default + market_tz = pytz.timezone("America/Chicago") + start_tz = market_tz.localize(start) + end_tz = market_tz.localize(end) time_based_key = ( - f"MGC_{start_utc.isoformat()}_{end_utc.isoformat()}_15_2_True" + f"MGC_{start_tz.isoformat()}_{end_tz.isoformat()}_15_2_True" ) days_based_key = "MGC_100_15_2_True" diff --git a/tests/conftest.py b/tests/conftest.py index 7d41920..a41d3f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -226,6 +226,7 @@ def mock_httpx_client(): mock_client = AsyncMock() mock_client.request = AsyncMock() mock_client.aclose = AsyncMock() + mock_client.is_closed = False # Add is_closed attribute return mock_client diff --git a/tests/order_manager/test_bracket_orders.py b/tests/order_manager/test_bracket_orders.py index 0481b6e..58493ad 100644 --- a/tests/order_manager/test_bracket_orders.py +++ b/tests/order_manager/test_bracket_orders.py @@ -68,6 +68,29 @@ async def test_bracket_order_success_flow(self): "BAR": {"entry_orders": [], "stop_orders": [], "target_orders": []} } mixin.stats = {"bracket_orders": 0} + # Mock the methods that are called from bracket_orders + mixin._wait_for_order_fill = AsyncMock(return_value=True) + mixin._link_oco_orders = AsyncMock() + + # Create a side effect that updates position_orders + async def mock_track_order(contract_id, order_id, order_type, account_id=None): + if contract_id not in mixin.position_orders: + mixin.position_orders[contract_id] = { + "entry_orders": [], + "stop_orders": [], + "target_orders": [], + } + if order_type == "entry": + mixin.position_orders[contract_id]["entry_orders"].append(order_id) + elif order_type == 
"stop": + mixin.position_orders[contract_id]["stop_orders"].append(order_id) + elif order_type == "target": + mixin.position_orders[contract_id]["target_orders"].append(order_id) + + mixin.track_order_for_position = AsyncMock(side_effect=mock_track_order) + mixin.close_position = AsyncMock() + mixin.cancel_order = AsyncMock() + mixin.oco_groups = {} # Entry type = limit resp = await mixin.place_bracket_order( diff --git a/tests/test_client_auth_simple.py b/tests/test_client_auth_simple.py new file mode 100644 index 0000000..e8a0bb8 --- /dev/null +++ b/tests/test_client_auth_simple.py @@ -0,0 +1,235 @@ +"""Simplified tests for the authentication module of ProjectX client.""" + +import asyncio +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, patch + +import jwt +import pytest + +from project_x_py.client.auth import AuthenticationMixin +from project_x_py.exceptions import ProjectXAuthenticationError +from project_x_py.models import Account + + +class MockAuthClient(AuthenticationMixin): + """Mock client that includes AuthenticationMixin for testing.""" + + def __init__(self): + super().__init__() + self.username = "test_user" + self.api_key = "test_api_key" + self.account_name = None + self.base_url = "https://api.test.com" + self.headers = {} + self._http_client = AsyncMock() + self._make_request = AsyncMock() + self._auth_lock = asyncio.Lock() + self._authenticated = False + self.jwt_token = None + self.session_token = None + self.account_info = None + + +class TestAuthenticationMixin: + """Test suite for AuthenticationMixin class.""" + + @pytest.fixture + def auth_client(self): + """Create a mock client with AuthenticationMixin for testing.""" + return MockAuthClient() + + @pytest.mark.asyncio + async def test_authenticate_success(self, auth_client): + """Test successful authentication flow.""" + # Mock responses for the two API calls + auth_response = {"token": "test_jwt_token"} + accounts_response = { + "success": True, + "accounts": [ + { + "id": 1, + "name": "Test Account", + "balance": 10000.0, + "canTrade": True, + "isVisible": True, + "simulated": False, + } + ], + } + + auth_client._make_request.side_effect = [auth_response, accounts_response] + + await auth_client.authenticate() + + assert auth_client.session_token == "test_jwt_token" + assert auth_client.account_info.name == "Test Account" + assert auth_client._authenticated is True + + @pytest.mark.asyncio + async def test_authenticate_with_specific_account(self, auth_client): + """Test authentication with specific account selection.""" + auth_client.account_name = "Second Account" + + auth_response = {"token": "test_jwt_token"} + accounts_response = { + "success": True, + "accounts": [ + { + "id": 1, + "name": "First Account", + "balance": 5000.0, + "canTrade": True, + "isVisible": True, + "simulated": False, + }, + { + "id": 2, + "name": "Second Account", + "balance": 10000.0, + "canTrade": True, + "isVisible": True, + "simulated": True, + }, + ], + } + + auth_client._make_request.side_effect = [auth_response, accounts_response] + + await auth_client.authenticate() + + assert auth_client.account_info.name == "Second Account" + assert auth_client.account_info.id == 2 + assert auth_client.account_info.simulated is True + + @pytest.mark.asyncio + async def test_authenticate_no_matching_account(self, auth_client): + """Test authentication fails when specified account not found.""" + auth_client.account_name = "Nonexistent Account" + + auth_response = {"token": "test_jwt_token"} + 
accounts_response = { + "success": True, + "accounts": [ + { + "id": 1, + "name": "Only Account", + "balance": 5000.0, + "canTrade": True, + "isVisible": True, + "simulated": False, + } + ], + } + + auth_client._make_request.side_effect = [auth_response, accounts_response] + + from project_x_py.exceptions import ProjectXError + + with pytest.raises(ProjectXError, match="not found"): + await auth_client.authenticate() + + @pytest.mark.asyncio + async def test_authenticate_no_accounts(self, auth_client): + """Test authentication fails when no accounts returned.""" + auth_response = {"token": "test_jwt_token"} + accounts_response = {"success": True, "accounts": []} + + auth_client._make_request.side_effect = [auth_response, accounts_response] + + with pytest.raises(ProjectXAuthenticationError, match="No accounts found"): + await auth_client.authenticate() + + @pytest.mark.asyncio + async def test_ensure_authenticated_when_not_authenticated(self, auth_client): + """Test _ensure_authenticated triggers authentication.""" + auth_client._authenticated = False + auth_client.authenticate = AsyncMock() + + await auth_client._ensure_authenticated() + + auth_client.authenticate.assert_called_once() + + @pytest.mark.asyncio + async def test_ensure_authenticated_when_authenticated(self, auth_client): + """Test _ensure_authenticated skips when already authenticated.""" + auth_client._authenticated = True + auth_client.jwt_token = "valid_token" + auth_client.account_info = Account( + id=1, + name="Test", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + auth_client.authenticate = AsyncMock() + auth_client._should_refresh_token = lambda: False # Mock to return False + + await auth_client._ensure_authenticated() + + auth_client.authenticate.assert_not_called() + + def test_should_refresh_token_near_expiry(self, auth_client): + """Test _should_refresh_token returns True when token is near expiry.""" + import pytz + + auth_client.token_expiry = datetime.now(pytz.UTC) + timedelta(minutes=4) + assert auth_client._should_refresh_token() is True + + def test_should_refresh_token_plenty_time(self, auth_client): + """Test _should_refresh_token returns False when token has time.""" + import pytz + + auth_client.token_expiry = datetime.now(pytz.UTC) + timedelta(hours=2) + assert auth_client._should_refresh_token() is False + + def test_should_refresh_token_no_expiry(self, auth_client): + """Test _should_refresh_token returns True when no expiry set.""" + auth_client.token_expiry = None + assert auth_client._should_refresh_token() is True + + @pytest.mark.asyncio + async def test_list_accounts(self, auth_client): + """Test listing all available accounts.""" + accounts_response = { + "success": True, + "accounts": [ + { + "id": 1, + "name": "Account 1", + "balance": 5000.0, + "canTrade": True, + "isVisible": True, + "simulated": False, + }, + { + "id": 2, + "name": "Account 2", + "balance": 10000.0, + "canTrade": True, + "isVisible": True, + "simulated": True, + }, + ], + } + + auth_client._make_request.return_value = accounts_response + auth_client._ensure_authenticated = AsyncMock() # Mock authentication check + + accounts = await auth_client.list_accounts() + + assert len(accounts) == 2 + assert accounts[0].name == "Account 1" + assert accounts[1].name == "Account 2" + + @pytest.mark.asyncio + async def test_authentication_error_handling(self, auth_client): + """Test proper error handling during authentication.""" + auth_client._make_request.side_effect = Exception("Connection failed") + 
+ with pytest.raises(Exception, match="Connection failed"): + await auth_client.authenticate() + + assert auth_client._authenticated is False + assert auth_client.jwt_token is None diff --git a/tests/test_client_base.py b/tests/test_client_base.py new file mode 100644 index 0000000..39539a1 --- /dev/null +++ b/tests/test_client_base.py @@ -0,0 +1,330 @@ +"""Comprehensive tests for the base module of ProjectX client.""" + +import os +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from project_x_py.client.base import ProjectXBase +from project_x_py.exceptions import ProjectXAuthenticationError +from project_x_py.models import Account, ProjectXConfig + + +class TestProjectXBase: + """Test suite for ProjectXBase class.""" + + @pytest.fixture + def mock_config(self): + """Create a mock configuration.""" + return ProjectXConfig( + api_url="https://api.test.com", + realtime_url="wss://realtime.test.com", + user_hub_url="/tradehub-userhub", + market_hub_url="/tradehub-markethub", + timezone="America/Chicago", + timeout_seconds=30, + retry_attempts=3, + retry_delay_seconds=1.0, + requests_per_minute=100, + burst_limit=10, + ) + + @pytest.fixture + def base_client(self, mock_config): + """Create a ProjectXBase client for testing.""" + return ProjectXBase( + username="testuser", + api_key="test-api-key", + config=mock_config, + account_name="TEST_ACCOUNT", + ) + + def test_initialization(self, base_client): + """Test client initialization.""" + assert base_client.username == "testuser" + assert base_client.api_key == "test-api-key" + assert base_client.account_name == "TEST_ACCOUNT" + assert base_client.base_url == "https://api.test.com" + assert base_client._client is None + assert base_client._authenticated is False + assert base_client.session_token == "" # Initialized as empty string + assert base_client.account_info is None + assert base_client.api_call_count == 0 + assert base_client.cache_hit_count == 0 + + def test_initialization_with_defaults(self): + """Test client initialization with default config.""" + client = ProjectXBase( + username="user", + api_key="key", + ) + assert client.username == "user" + assert client.api_key == "key" + assert client.account_name is None + assert client.base_url == "https://api.topstepx.com/api" # Default URL + + @pytest.mark.asyncio + async def test_context_manager(self, base_client): + """Test async context manager functionality.""" + mock_http_client = AsyncMock() + + with patch( + "project_x_py.client.base.httpx.AsyncClient", return_value=mock_http_client + ): + async with base_client as client: + assert client is base_client + assert base_client._client is not None + + # After exiting context, client should be closed + mock_http_client.aclose.assert_called_once() + assert base_client._client is None + + @pytest.mark.asyncio + async def test_context_manager_with_exception(self, base_client): + """Test context manager handles exceptions properly.""" + mock_http_client = AsyncMock() + + with patch( + "project_x_py.client.base.httpx.AsyncClient", return_value=mock_http_client + ): + with pytest.raises(ValueError, match="Test exception"): + async with base_client: + raise ValueError("Test exception") + + # Client should still be closed even with exception + mock_http_client.aclose.assert_called_once() + + def test_get_session_token_when_authenticated(self, base_client): + """Test getting session token when authenticated.""" + base_client._authenticated = True + base_client.session_token = "test-session-token" + + token = 
base_client.get_session_token() + assert token == "test-session-token" + + def test_get_session_token_when_not_authenticated(self, base_client): + """Test getting session token when not authenticated.""" + base_client._authenticated = False + + with pytest.raises(ProjectXAuthenticationError, match="Not authenticated"): + base_client.get_session_token() + + def test_get_session_token_no_token(self, base_client): + """Test getting session token when authenticated but no token.""" + base_client._authenticated = True + base_client.session_token = "" # Empty string counts as no token + + with pytest.raises(ProjectXAuthenticationError, match="Not authenticated"): + base_client.get_session_token() + + def test_get_account_info_when_available(self, base_client): + """Test getting account info when available.""" + account = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + base_client.account_info = account + + result = base_client.get_account_info() + assert result == account + + def test_get_account_info_when_not_available(self, base_client): + """Test getting account info when not available.""" + base_client.account_info = None + + with pytest.raises(ProjectXAuthenticationError, match="No account selected"): + base_client.get_account_info() + + @pytest.mark.asyncio + async def test_from_env_success(self): + """Test creating client from environment variables.""" + with patch.dict( + os.environ, + { + "PROJECT_X_USERNAME": "env_user", + "PROJECT_X_API_KEY": "env_key", + "PROJECT_X_ACCOUNT_NAME": "env_account", + }, + ): + with patch("project_x_py.client.base.ConfigManager") as mock_config_manager: + mock_manager = Mock() + mock_manager.get_auth_config.return_value = { + "username": "env_user", + "api_key": "env_key", + } + mock_config_manager.return_value = mock_manager + + mock_http_client = AsyncMock() + with patch( + "project_x_py.client.base.httpx.AsyncClient", + return_value=mock_http_client, + ): + async with ProjectXBase.from_env() as client: + assert client.username == "env_user" + assert client.api_key == "env_key" + assert ( + client.account_name == "ENV_ACCOUNT" + ) # Should be uppercase + + @pytest.mark.asyncio + async def test_from_env_with_custom_account(self): + """Test creating client from environment with custom account name.""" + with patch.dict( + os.environ, + { + "PROJECT_X_USERNAME": "env_user", + "PROJECT_X_API_KEY": "env_key", + }, + ): + with patch("project_x_py.client.base.ConfigManager") as mock_config_manager: + mock_manager = Mock() + mock_manager.get_auth_config.return_value = { + "username": "env_user", + "api_key": "env_key", + } + mock_config_manager.return_value = mock_manager + + mock_http_client = AsyncMock() + with patch( + "project_x_py.client.base.httpx.AsyncClient", + return_value=mock_http_client, + ): + async with ProjectXBase.from_env( + account_name="custom_account" + ) as client: + assert client.account_name == "CUSTOM_ACCOUNT" + + @pytest.mark.asyncio + async def test_from_env_with_custom_config(self): + """Test creating client from environment with custom config.""" + custom_config = ProjectXConfig( + api_url="https://custom.api.com", + realtime_url="wss://custom.realtime.com", + user_hub_url="/custom-userhub", + market_hub_url="/custom-markethub", + timezone="Europe/London", + timeout_seconds=60, + retry_attempts=5, + retry_delay_seconds=2.0, + requests_per_minute=200, + burst_limit=20, + ) + + with patch.dict( + os.environ, + { + "PROJECT_X_USERNAME": "env_user", + 
"PROJECT_X_API_KEY": "env_key", + }, + ): + with patch("project_x_py.client.base.ConfigManager") as mock_config_manager: + mock_manager = Mock() + mock_manager.get_auth_config.return_value = { + "username": "env_user", + "api_key": "env_key", + } + mock_config_manager.return_value = mock_manager + + mock_http_client = AsyncMock() + with patch( + "project_x_py.client.base.httpx.AsyncClient", + return_value=mock_http_client, + ): + async with ProjectXBase.from_env(config=custom_config) as client: + assert client.config == custom_config + assert client.base_url == "https://custom.api.com" + + @pytest.mark.asyncio + async def test_from_config_file(self): + """Test creating client from config file.""" + with patch("project_x_py.client.base.ConfigManager") as mock_config_manager: + mock_manager = Mock() + mock_config = ProjectXConfig( + api_url="https://file.api.com", + realtime_url="wss://file.realtime.com", + user_hub_url="/file-userhub", + market_hub_url="/file-markethub", + timezone="US/Pacific", + timeout_seconds=45, + retry_attempts=4, + retry_delay_seconds=1.5, + requests_per_minute=150, + burst_limit=15, + ) + mock_manager.load_config.return_value = mock_config + mock_manager.get_auth_config.return_value = { + "username": "file_user", + "api_key": "file_key", + } + mock_config_manager.return_value = mock_manager + + mock_http_client = AsyncMock() + with patch( + "project_x_py.client.base.httpx.AsyncClient", + return_value=mock_http_client, + ): + async with ProjectXBase.from_config_file("test_config.json") as client: + assert client.username == "file_user" + assert client.api_key == "file_key" + assert client.base_url == "https://file.api.com" + + # Verify ConfigManager was called with the config file + mock_config_manager.assert_called_once_with("test_config.json") + + @pytest.mark.asyncio + async def test_from_config_file_with_account_name(self): + """Test creating client from config file with account name.""" + with patch("project_x_py.client.base.ConfigManager") as mock_config_manager: + mock_manager = Mock() + mock_config = ProjectXConfig( + api_url="https://file.api.com", + realtime_url="wss://file.realtime.com", + user_hub_url="/file-userhub", + market_hub_url="/file-markethub", + timezone="US/Pacific", + timeout_seconds=45, + retry_attempts=4, + retry_delay_seconds=1.5, + requests_per_minute=150, + burst_limit=15, + ) + mock_manager.load_config.return_value = mock_config + mock_manager.get_auth_config.return_value = { + "username": "file_user", + "api_key": "file_key", + } + mock_config_manager.return_value = mock_manager + + mock_http_client = AsyncMock() + with patch( + "project_x_py.client.base.httpx.AsyncClient", + return_value=mock_http_client, + ): + async with ProjectXBase.from_config_file( + "test_config.json", account_name="file_account" + ) as client: + assert client.account_name == "FILE_ACCOUNT" + + def test_headers_property(self, base_client): + """Test headers property.""" + assert base_client.headers == {"Content-Type": "application/json"} + + def test_config_property(self, mock_config): + """Test config property.""" + client = ProjectXBase( + username="user", + api_key="key", + config=mock_config, + ) + assert client.config == mock_config + assert client.config.timezone == "America/Chicago" + + def test_rate_limiter_initialization(self, base_client): + """Test rate limiter is properly initialized.""" + assert base_client.rate_limiter is not None + assert base_client.rate_limiter.max_requests == 100 + assert base_client.rate_limiter.window_seconds == 60 diff --git 
a/tests/test_client_cache.py b/tests/test_client_cache.py new file mode 100644 index 0000000..44b0e18 --- /dev/null +++ b/tests/test_client_cache.py @@ -0,0 +1,363 @@ +"""Tests for the cache module of ProjectX client.""" + +import io +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +import lz4.frame +import polars as pl +import pytest +import pytz + +from project_x_py.client.cache import CacheMixin +from project_x_py.models import Instrument + + +class MockCacheClient(CacheMixin): + """Mock client that includes CacheMixin for testing.""" + + def __init__(self): + # Set config before calling super().__init__() + self.config = {"compression_threshold": 1024, "compression_level": 3} + super().__init__() + + +class TestCacheMixin: + """Test suite for CacheMixin class.""" + + @pytest.fixture + def cache_client(self): + """Create a mock client with CacheMixin for testing.""" + return MockCacheClient() + + def test_initialization(self, cache_client): + """Test cache initialization.""" + assert cache_client.cache_ttl == 300 # Default 5 minutes + assert cache_client.cache_hit_count == 0 + assert cache_client.compression_threshold == 1024 + assert cache_client.compression_level == 3 + assert len(cache_client._opt_instrument_cache) == 0 + assert len(cache_client._opt_market_data_cache) == 0 + + def test_cache_ttl_setter(self, cache_client): + """Test setting cache TTL recreates caches.""" + # Add some data first + instrument = Instrument( + id="TEST123", + name="TEST", + description="Test instrument", + tickSize=0.25, + tickValue=12.5, + activeContract=True, + symbolId="TEST", + ) + cache_client.cache_instrument("TEST", instrument) + + # Change TTL + cache_client.cache_ttl = 600 + + # Verify new TTL and caches are cleared + assert cache_client.cache_ttl == 600 + assert cache_client._opt_instrument_cache.ttl == 600 + assert cache_client._opt_market_data_cache.ttl == 600 + # Caches should be recreated (empty) + assert len(cache_client._opt_instrument_cache) == 0 + + def test_cache_instrument(self, cache_client): + """Test caching an instrument.""" + instrument = Instrument( + id="CON.F.US.MNQ.U25", + name="MNQU25", + description="Micro E-mini Nasdaq-100", + tickSize=0.25, + tickValue=12.5, + activeContract=True, + symbolId="MNQ", + ) + + cache_client.cache_instrument("MNQ", instrument) + + assert "MNQ" in cache_client._opt_instrument_cache + assert cache_client._opt_instrument_cache["MNQ"] == instrument + + def test_get_cached_instrument_hit(self, cache_client): + """Test getting a cached instrument (cache hit).""" + instrument = Instrument( + id="CON.F.US.ES.H25", + name="ESH25", + description="E-mini S&P 500", + tickSize=0.25, + tickValue=12.5, + activeContract=True, + symbolId="ES", + ) + + cache_client.cache_instrument("ES", instrument) + cache_hit_count_before = cache_client.cache_hit_count + + cached = cache_client.get_cached_instrument("ES") + + assert cached == instrument + assert cache_client.cache_hit_count == cache_hit_count_before + 1 + + def test_get_cached_instrument_miss(self, cache_client): + """Test getting a non-cached instrument (cache miss).""" + cached = cache_client.get_cached_instrument("UNKNOWN") + + assert cached is None + assert cache_client.cache_hit_count == 0 + + def test_get_cached_instrument_case_insensitive(self, cache_client): + """Test that instrument cache is case-insensitive.""" + instrument = Instrument( + id="TEST", + name="TEST", + description="Test", + tickSize=1.0, + tickValue=1.0, + activeContract=True, + symbolId="TEST", + ) + 
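# Stored under a lowercase key on purpose; the lookups below vary the casing.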
+ cache_client.cache_instrument("test", instrument) + + # Should find it with different case + cached = cache_client.get_cached_instrument("TEST") + assert cached == instrument + + cached = cache_client.get_cached_instrument("TeSt") + assert cached == instrument + + def test_serialize_dataframe_small(self, cache_client): + """Test serializing a small DataFrame (no compression).""" + # Create a very small DataFrame that won't exceed compression threshold + df = pl.DataFrame({"value": [1.0]}) + + serialized = cache_client._serialize_dataframe(df) + + # Check the actual size to determine if it's compressed + # IPC format adds overhead, so even small DataFrames might be > 1KB + if len(serialized) - 3 > cache_client.compression_threshold: + assert serialized.startswith(b"LZ4") # Compressed + else: + assert serialized.startswith(b"RAW") # Not compressed + assert len(serialized) > 3 + + def test_serialize_dataframe_large(self, cache_client): + """Test serializing a large DataFrame (with compression).""" + # Create a large DataFrame that exceeds compression threshold + df = pl.DataFrame( + { + "timestamp": [datetime.now(pytz.UTC)] * 100, + "open": [100.0] * 100, + "high": [101.0] * 100, + "low": [99.0] * 100, + "close": [100.5] * 100, + "volume": [1000] * 100, + } + ) + + serialized = cache_client._serialize_dataframe(df) + + assert serialized.startswith(b"LZ4") # Compressed + assert len(serialized) > 3 + + def test_serialize_empty_dataframe(self, cache_client): + """Test serializing an empty DataFrame.""" + df = pl.DataFrame() + + serialized = cache_client._serialize_dataframe(df) + + assert serialized == b"" + + def test_deserialize_dataframe_small(self, cache_client): + """Test deserializing a small DataFrame.""" + original_df = pl.DataFrame( + { + "timestamp": [datetime(2025, 1, 15, 10, 30, tzinfo=pytz.UTC)], + "open": [100.0], + "high": [101.0], + "low": [99.0], + "close": [100.5], + "volume": [1000], + } + ) + + serialized = cache_client._serialize_dataframe(original_df) + deserialized = cache_client._deserialize_dataframe(serialized) + + assert deserialized is not None + assert deserialized.equals(original_df) + + def test_deserialize_dataframe_large(self, cache_client): + """Test deserializing a large compressed DataFrame.""" + # Create a large DataFrame + original_df = pl.DataFrame( + { + "timestamp": [ + datetime(2025, 1, 15, i, 0, tzinfo=pytz.UTC) for i in range(24) + ], + "open": [100.0 + i for i in range(24)], + "high": [101.0 + i for i in range(24)], + "low": [99.0 + i for i in range(24)], + "close": [100.5 + i for i in range(24)], + "volume": [1000 * i for i in range(24)], + } + ) + + serialized = cache_client._serialize_dataframe(original_df) + deserialized = cache_client._deserialize_dataframe(serialized) + + assert deserialized is not None + assert len(deserialized) == len(original_df) + assert list(deserialized.columns) == list(original_df.columns) + + def test_deserialize_empty_data(self, cache_client): + """Test deserializing empty data.""" + deserialized = cache_client._deserialize_dataframe(b"") + + assert deserialized is None + + def test_deserialize_corrupted_lz4(self, cache_client): + """Test deserializing corrupted LZ4 data.""" + corrupted = b"LZ4" + b"corrupted_data" + + with patch("project_x_py.client.cache.logger") as mock_logger: + deserialized = cache_client._deserialize_dataframe(corrupted) + + assert deserialized is None + mock_logger.warning.assert_called_once() + + def test_deserialize_unknown_header(self, cache_client): + """Test deserializing data with unknown 
header.""" + unknown = b"XXX" + b"some_data" + + with patch("project_x_py.client.cache.logger") as mock_logger: + deserialized = cache_client._deserialize_dataframe(unknown) + + assert deserialized is None + mock_logger.warning.assert_called_once() + + def test_cache_market_data(self, cache_client): + """Test caching market data.""" + df = pl.DataFrame( + {"timestamp": [datetime.now(pytz.UTC)], "close": [100.0], "volume": [1000]} + ) + + cache_key = "TEST_KEY" + cache_client.cache_market_data(cache_key, df) + + assert cache_key in cache_client._opt_market_data_cache + assert len(cache_client._opt_market_data_cache[cache_key]) > 0 + + def test_get_cached_market_data_hit(self, cache_client): + """Test getting cached market data (cache hit).""" + original_df = pl.DataFrame( + {"timestamp": [datetime.now(pytz.UTC)], "close": [100.0], "volume": [1000]} + ) + + cache_key = "TEST_KEY" + cache_client.cache_market_data(cache_key, original_df) + cache_hit_count_before = cache_client.cache_hit_count + + cached_df = cache_client.get_cached_market_data(cache_key) + + assert cached_df is not None + assert cached_df.equals(original_df) + assert cache_client.cache_hit_count == cache_hit_count_before + 1 + + def test_get_cached_market_data_miss(self, cache_client): + """Test getting non-cached market data (cache miss).""" + cached_df = cache_client.get_cached_market_data("UNKNOWN_KEY") + + assert cached_df is None + assert cache_client.cache_hit_count == 0 + + def test_get_cached_market_data_corrupted(self, cache_client): + """Test getting market data with corrupted cache entry.""" + cache_key = "CORRUPTED" + cache_client._opt_market_data_cache[cache_key] = b"corrupted_data" + + cached_df = cache_client.get_cached_market_data(cache_key) + + assert cached_df is None + assert cache_client.cache_hit_count == 0 + + def test_clear_all_caches(self, cache_client): + """Test clearing all caches.""" + # Add some data + instrument = Instrument( + id="TEST", + name="TEST", + description="Test", + tickSize=1.0, + tickValue=1.0, + activeContract=True, + symbolId="TEST", + ) + cache_client.cache_instrument("TEST", instrument) + + df = pl.DataFrame({"value": [1, 2, 3]}) + cache_client.cache_market_data("TEST_DATA", df) + cache_client.cache_hit_count = 10 + + # Clear caches + cache_client.clear_all_caches() + + assert len(cache_client._opt_instrument_cache) == 0 + assert len(cache_client._opt_market_data_cache) == 0 + assert cache_client.cache_hit_count == 0 + + def test_get_cache_stats(self, cache_client): + """Test getting cache statistics.""" + # Add some data + instrument = Instrument( + id="TEST", + name="TEST", + description="Test", + tickSize=1.0, + tickValue=1.0, + activeContract=True, + symbolId="TEST", + ) + cache_client.cache_instrument("TEST", instrument) + + df = pl.DataFrame({"value": [1, 2, 3]}) + cache_client.cache_market_data("TEST_DATA", df) + cache_client.cache_hit_count = 5 + + stats = cache_client.get_cache_stats() + + assert stats["cache_hits"] == 5 + assert stats["instrument_cache_size"] == 1 + assert stats["market_data_cache_size"] == 1 + assert stats["instrument_cache_max"] == 1000 + assert stats["market_data_cache_max"] == 10000 + assert stats["compression_enabled"] is True + assert stats["serialization"] == "arrow-ipc" + assert stats["compression"] == "lz4" + + def test_cache_ttl_expiration(self, cache_client): + """Test that cache respects TTL expiration.""" + # Set very short TTL + cache_client.cache_ttl = 0.001 # 1 millisecond + + instrument = Instrument( + id="TEST", + name="TEST", + 
description="Test", + tickSize=1.0, + tickValue=1.0, + activeContract=True, + symbolId="TEST", + ) + cache_client.cache_instrument("TEST", instrument) + + # Wait for expiration + import time + + time.sleep(0.01) + + # Should be expired + cached = cache_client.get_cached_instrument("TEST") + assert cached is None diff --git a/tests/test_client_http.py b/tests/test_client_http.py new file mode 100644 index 0000000..24e3780 --- /dev/null +++ b/tests/test_client_http.py @@ -0,0 +1,354 @@ +"""Tests for the HTTP module of ProjectX client.""" + +import time +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest + +from project_x_py.client.http import HttpMixin +from project_x_py.exceptions import ( + ProjectXAuthenticationError, + ProjectXConnectionError, + ProjectXDataError, + ProjectXError, + ProjectXRateLimitError, + ProjectXServerError, +) + + +class MockHttpClient(HttpMixin): + """Mock client that includes HttpMixin for testing.""" + + def __init__(self): + super().__init__() + self.base_url = "https://api.test.com" + self.headers = {"X-Test": "test"} + self.session_token = None + self.config = Mock() + self.config.timeout_seconds = 30 + self.rate_limiter = Mock() + self.rate_limiter.acquire = AsyncMock() + self.cache_hit_count = 0 + self._refresh_authentication = AsyncMock() + + +class TestHttpMixin: + """Test suite for HttpMixin class.""" + + @pytest.fixture + def http_client(self): + """Create a mock client with HttpMixin for testing.""" + return MockHttpClient() + + @pytest.mark.asyncio + async def test_create_client(self, http_client): + """Test HTTP client creation with proper configuration.""" + client = await http_client._create_client() + + assert isinstance(client, httpx.AsyncClient) + assert client.timeout.connect == 5.0 + assert client.timeout.read == 30 + assert client.follow_redirects is True + # HTTP/2 is enabled via parameter but not exposed as attribute + + await client.aclose() + + @pytest.mark.asyncio + async def test_ensure_client_creates_new(self, http_client): + """Test _ensure_client creates client when none exists.""" + assert http_client._client is None + + client = await http_client._ensure_client() + + assert client is not None + assert http_client._client is client + + await client.aclose() + + @pytest.mark.asyncio + async def test_ensure_client_reuses_existing(self, http_client): + """Test _ensure_client reuses existing client.""" + mock_client = Mock(spec=httpx.AsyncClient) + http_client._client = mock_client + + client = await http_client._ensure_client() + + assert client is mock_client + + @pytest.mark.asyncio + async def test_make_request_success(self, http_client): + """Test successful API request.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True, "data": "test"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + result = await http_client._make_request("GET", "/test") + + assert result == {"success": True, "data": "test"} + assert http_client.api_call_count == 1 + mock_client.request.assert_called_once() + http_client.rate_limiter.acquire.assert_called_once() + + @pytest.mark.asyncio + async def test_make_request_with_data_and_params(self, http_client): + """Test API request with data and parameters.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "ok"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) 
+ mock_client.request.return_value = mock_response + http_client._client = mock_client + + data = {"key": "value"} + params = {"param": "test"} + + result = await http_client._make_request( + "POST", "/test", data=data, params=params + ) + + assert result == {"result": "ok"} + call_args = mock_client.request.call_args + assert call_args.kwargs["json"] == data + assert call_args.kwargs["params"] == params + + @pytest.mark.asyncio + async def test_make_request_with_auth_token(self, http_client): + """Test that auth token is included in headers.""" + http_client.session_token = "test_token" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + await http_client._make_request("GET", "/test") + + call_args = mock_client.request.call_args + assert "Authorization" in call_args.kwargs["headers"] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer test_token" + + @pytest.mark.asyncio + async def test_make_request_no_auth_for_login(self, http_client): + """Test that auth token is not included for login endpoint.""" + http_client.session_token = "test_token" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"token": "new_token"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + await http_client._make_request("POST", "/Auth/loginKey") + + call_args = mock_client.request.call_args + # Should not have Authorization header for login endpoint + assert "Authorization" not in call_args.kwargs["headers"] + + @pytest.mark.asyncio + @patch("asyncio.sleep", return_value=None) # Mock sleep to avoid waiting + async def test_make_request_rate_limit_error(self, mock_sleep, http_client): + """Test handling of rate limit errors.""" + mock_response = Mock() + mock_response.status_code = 429 + mock_response.headers = {"Retry-After": "30"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXRateLimitError, match="Rate limit exceeded"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_make_request_connection_error(self, http_client): + """Test handling of connection errors.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.side_effect = httpx.ConnectError("Connection failed") + http_client._client = mock_client + + with pytest.raises(ProjectXConnectionError, match="Connection failed"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_make_request_timeout_error(self, http_client): + """Test handling of timeout errors.""" + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.side_effect = httpx.TimeoutException("Request timed out") + http_client._client = mock_client + + with pytest.raises(ProjectXConnectionError, match="Request timed out"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_make_request_401_refresh_auth(self, http_client): + """Test that 401 triggers authentication refresh.""" + http_client.session_token = "expired_token" + + # First response: 401 error + mock_response_401 = Mock() + mock_response_401.status_code = 401 + + # Second response after refresh: success + 
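# The handler should refresh credentials and retry exactly once (see the call_count assertion below). +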
mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.json.return_value = {"data": "refreshed"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.side_effect = [mock_response_401, mock_response_success] + http_client._client = mock_client + + result = await http_client._make_request("GET", "/test") + + assert result == {"data": "refreshed"} + http_client._refresh_authentication.assert_called_once() + assert mock_client.request.call_count == 2 + + @pytest.mark.asyncio + async def test_make_request_401_on_login_endpoint(self, http_client): + """Test that 401 on login endpoint raises error without refresh.""" + mock_response = Mock() + mock_response.status_code = 401 + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXAuthenticationError, match="Authentication failed"): + await http_client._make_request("POST", "/Auth/loginKey") + + http_client._refresh_authentication.assert_not_called() + + @pytest.mark.asyncio + async def test_make_request_404_error(self, http_client): + """Test handling of 404 errors.""" + mock_response = Mock() + mock_response.status_code = 404 + mock_response.json.side_effect = Exception("Not JSON") + mock_response.text = "Not found" + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXDataError, match="Resource not found"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_make_request_400_error_with_message(self, http_client): + """Test handling of 400 errors with error message.""" + mock_response = Mock() + mock_response.status_code = 400 + mock_response.json.return_value = {"message": "Invalid request parameters"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXError, match="Invalid request parameters"): + await http_client._make_request("POST", "/test") + + @pytest.mark.asyncio + async def test_make_request_400_error_with_error_field(self, http_client): + """Test handling of 400 errors with error field.""" + mock_response = Mock() + mock_response.status_code = 400 + mock_response.json.return_value = {"error": "Bad request"} + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXError, match="Bad request"): + await http_client._make_request("POST", "/test") + + @pytest.mark.asyncio + async def test_make_request_500_error(self, http_client): + """Test handling of server errors.""" + mock_response = Mock() + mock_response.status_code = 500 + mock_response.text = "Internal server error occurred" + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXServerError, match="Server error"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_make_request_204_no_content(self, http_client): + """Test handling of 204 No Content response.""" + mock_response = Mock() + mock_response.status_code = 204 + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + result = await 
http_client._make_request("DELETE", "/test") + + assert result == {} + + @pytest.mark.asyncio + async def test_make_request_json_parse_error(self, http_client): + """Test handling of JSON parsing errors.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.side_effect = ValueError("Invalid JSON") + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request.return_value = mock_response + http_client._client = mock_client + + with pytest.raises(ProjectXDataError, match="Failed to parse"): + await http_client._make_request("GET", "/test") + + @pytest.mark.asyncio + async def test_get_health_status(self, http_client): + """Test health status reporting.""" + http_client.api_call_count = 10 + http_client.cache_hit_count = 5 + http_client._client = Mock(spec=httpx.AsyncClient) + http_client._client.is_closed = False + + status = await http_client.get_health_status() + + assert status["api_calls"] == 10 + assert status["cache_hits"] == 5 + assert status["cache_misses"] == 10 + assert status["total_requests"] == 15 + assert status["cache_hit_ratio"] == 5 / 15 + assert status["active_connections"] == 1 + + @pytest.mark.asyncio + async def test_get_health_status_no_requests(self, http_client): + """Test health status with no requests.""" + status = await http_client.get_health_status() + + assert status["api_calls"] == 0 + assert status["cache_hits"] == 0 + assert status["cache_misses"] == 0 + assert status["total_requests"] == 0 + assert status["cache_hit_ratio"] == 0 + assert status["active_connections"] == 0 + + @pytest.mark.asyncio + async def test_get_health_status_closed_client(self, http_client): + """Test health status with closed client.""" + http_client._client = Mock(spec=httpx.AsyncClient) + http_client._client.is_closed = True + + status = await http_client.get_health_status() + + assert status["active_connections"] == 0 diff --git a/tests/test_client_market_data.py b/tests/test_client_market_data.py new file mode 100644 index 0000000..8d92a07 --- /dev/null +++ b/tests/test_client_market_data.py @@ -0,0 +1,470 @@ +"""Tests for the market data module of ProjectX client.""" + +import asyncio +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, Mock, patch + +import polars as pl +import pytest +import pytz + +from project_x_py.client.market_data import MarketDataMixin +from project_x_py.exceptions import ProjectXDataError, ProjectXInstrumentError +from project_x_py.models import Instrument + + +class MockMarketDataClient(MarketDataMixin): + """Mock client that includes MarketDataMixin for testing.""" + + def __init__(self): + super().__init__() + self.base_url = "https://api.test.com" + self._http_client = AsyncMock() + self._make_request = AsyncMock() + self._ensure_authenticated = AsyncMock() + self._authenticated = False + self.config = Mock() + self.config.timezone = "America/Chicago" + self.logger = Mock() + # Mock cache methods + self.get_cached_instrument = Mock(return_value=None) + self.cache_instrument = Mock() + self.get_cached_market_data = Mock(return_value=None) + self.cache_market_data = Mock() + + +class TestMarketDataMixin: + """Test suite for MarketDataMixin class.""" + + @pytest.fixture + def market_client(self): + """Create a mock client with MarketDataMixin for testing.""" + return MockMarketDataClient() + + @pytest.mark.asyncio + async def test_get_instrument_success(self, market_client): + """Test successful instrument retrieval.""" + mock_response = { + "success": True, + "contracts": [ + { + "id": 
"CON.F.US.MNQ.U25", + "name": "MNQU25", + "description": "Micro E-mini Nasdaq-100 Sep 2025", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + } + ], + } + market_client._make_request.return_value = mock_response + + instrument = await market_client.get_instrument("MNQ") + + assert instrument.id == "CON.F.US.MNQ.U25" + assert instrument.name == "MNQU25" + assert instrument.description == "Micro E-mini Nasdaq-100 Sep 2025" + assert instrument.tickSize == 0.25 + assert instrument.tickValue == 12.5 + assert instrument.activeContract is True + market_client.cache_instrument.assert_called_once() + + @pytest.mark.asyncio + async def test_get_instrument_with_cache_hit(self, market_client): + """Test instrument retrieval when cache has the data.""" + cached_instrument = Instrument( + id="CON.F.US.MNQ.U25", + name="MNQU25", + description="Micro E-mini Nasdaq-100 Sep 2025", + tickSize=0.25, + tickValue=12.5, + activeContract=True, + symbolId="MNQ", + ) + market_client.get_cached_instrument.return_value = cached_instrument + + instrument = await market_client.get_instrument("MNQ") + + assert instrument == cached_instrument + market_client._make_request.assert_not_called() + + @pytest.mark.asyncio + async def test_get_instrument_direct_contract_id(self, market_client): + """Test getting instrument by direct contract ID.""" + mock_response = { + "success": True, + "contracts": [ + { + "id": "CON.F.US.MNQ.U25", + "name": "MNQU25", + "description": "Micro E-mini Nasdaq-100 futures", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + } + ], + } + market_client._make_request.return_value = mock_response + + instrument = await market_client.get_instrument("CON.F.US.MNQ.U25") + + assert instrument.id == "CON.F.US.MNQ.U25" + # Should search by ID when format matches CON.* + market_client._make_request.assert_called_once() + call_args = market_client._make_request.call_args + # Check that it uses /Contract/search endpoint + assert call_args[0][1] == "/Contract/search" + + @pytest.mark.asyncio + async def test_get_instrument_no_results(self, market_client): + """Test instrument retrieval with no results.""" + mock_response = {"success": True, "contracts": []} + market_client._make_request.return_value = mock_response + + with pytest.raises(ProjectXInstrumentError, match="Instrument not found"): + await market_client.get_instrument("INVALID") + + @pytest.mark.asyncio + async def test_get_instrument_api_error(self, market_client): + """Test instrument retrieval with API error.""" + mock_response = {"success": False, "error": "API Error"} + market_client._make_request.return_value = mock_response + + with pytest.raises(Exception): # Will be wrapped by handle_errors + await market_client.get_instrument("MNQ") + + @pytest.mark.asyncio + async def test_search_instruments_success(self, market_client): + """Test successful instrument search.""" + mock_response = { + "success": True, + "contracts": [ + { + "id": "CON.F.US.MNQ.U25", + "name": "MNQU25", + "description": "Micro E-mini Nasdaq-100 futures", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + }, + { + "id": "CON.F.US.MNQ.Z25", + "name": "MNQZ25", + "description": "Micro E-mini Nasdaq-100 futures", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + }, + ], + } + market_client._make_request.return_value = mock_response + + instruments = await market_client.search_instruments("MNQ") + + assert 
len(instruments) == 2 + assert instruments[0].name == "MNQU25" + assert instruments[1].name == "MNQZ25" + + @pytest.mark.asyncio + async def test_search_instruments_live_only(self, market_client): + """Test instrument search with live-only filter.""" + mock_response = {"success": True, "contracts": []} + market_client._make_request.return_value = mock_response + + instruments = await market_client.search_instruments("MNQ", live=True) + + assert instruments == [] + # Check that live parameter was passed + call_args = market_client._make_request.call_args + assert "data" in call_args[1] + assert call_args[1]["data"]["live"] is True + + @pytest.mark.asyncio + async def test_get_bars_success(self, market_client): + """Test successful bar data retrieval.""" + # Mock instrument response + instrument_response = { + "success": True, + "contracts": [ + { + "id": "CON.F.US.MNQ.U25", + "name": "MNQU25", + "description": "Micro E-mini Nasdaq-100 futures", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + } + ], + } + + # Mock bars response + bars_response = { + "success": True, + "bars": [ + { + "t": "2025-01-15T14:00:00Z", + "o": 21000.0, + "h": 21050.0, + "l": 20950.0, + "c": 21025.0, + "v": 1500, + }, + { + "t": "2025-01-15T14:05:00Z", + "o": 21025.0, + "h": 21075.0, + "l": 21020.0, + "c": 21070.0, + "v": 2000, + }, + ], + } + + market_client._make_request.side_effect = [instrument_response, bars_response] + + bars = await market_client.get_bars("MNQ", days=1, interval=5) + + assert isinstance(bars, pl.DataFrame) + assert len(bars) == 2 + assert "timestamp" in bars.columns + assert "open" in bars.columns + assert "high" in bars.columns + assert "low" in bars.columns + assert "close" in bars.columns + assert "volume" in bars.columns + assert bars["close"][0] == 21025.0 + assert bars["volume"][1] == 2000 + + @pytest.mark.asyncio + async def test_get_bars_with_cache_hit(self, market_client): + """Test bar data retrieval when cache has the data.""" + cached_bars = pl.DataFrame( + { + "timestamp": [datetime.now(pytz.UTC)], + "open": [21000.0], + "high": [21050.0], + "low": [20950.0], + "close": [21025.0], + "volume": [1500], + } + ) + market_client.get_cached_market_data.return_value = cached_bars + + bars = await market_client.get_bars("MNQ", days=1) + + assert bars.equals(cached_bars) + market_client._make_request.assert_not_called() + + @pytest.mark.asyncio + async def test_get_bars_with_time_range(self, market_client): + """Test bar data retrieval with specific time range.""" + # Mock instrument response + instrument_response = { + "success": True, + "contracts": [ + { + "id": "CON.F.US.MNQ.U25", + "name": "MNQU25", + "description": "Micro E-mini Nasdaq-100 futures", + "tickSize": 0.25, + "tickValue": 12.5, + "activeContract": True, + "symbolId": "MNQ", + } + ], + } + + # Mock bars response + bars_response = {"success": True, "bars": []} + + market_client._make_request.side_effect = [instrument_response, bars_response] + + start_time = datetime(2025, 1, 1, 9, 30) + end_time = datetime(2025, 1, 1, 16, 0) + + bars = await market_client.get_bars( + "MNQ", start_time=start_time, end_time=end_time, interval=15 + ) + + # Verify the request included time range parameters + calls = market_client._make_request.call_args_list + bars_call = calls[1] + assert "data" in bars_call[1] + assert "startTime" in bars_call[1]["data"] + assert "endTime" in bars_call[1]["data"] + + @pytest.mark.asyncio + async def test_get_bars_empty_response(self, market_client): + """Test bar 
data retrieval with empty response."""
+        # Mock instrument response
+        instrument_response = {
+            "success": True,
+            "contracts": [
+                {
+                    "id": "CON.F.US.MNQ.U25",
+                    "name": "MNQU25",
+                    "description": "Micro E-mini Nasdaq-100 futures",
+                    "tickSize": 0.25,
+                    "tickValue": 12.5,
+                    "activeContract": True,
+                    "symbolId": "MNQ",
+                }
+            ],
+        }
+
+        bars_response = {"success": True, "bars": []}
+
+        market_client._make_request.side_effect = [instrument_response, bars_response]
+
+        bars = await market_client.get_bars("MNQ", days=1)
+
+        assert isinstance(bars, pl.DataFrame)
+        assert len(bars) == 0
+        # Empty DataFrame won't have columns
+
+    @pytest.mark.asyncio
+    async def test_get_bars_api_error(self, market_client):
+        """Test bar data retrieval with API error."""
+        # Mock instrument response
+        instrument_response = {
+            "success": True,
+            "contracts": [
+                {
+                    "id": "CON.F.US.MNQ.U25",
+                    "name": "MNQU25",
+                    "description": "Micro E-mini Nasdaq-100 futures",
+                    "tickSize": 0.25,
+                    "tickValue": 12.5,
+                    "activeContract": True,
+                    "symbolId": "MNQ",
+                }
+            ],
+        }
+
+        market_client._make_request.side_effect = [
+            instrument_response,
+            Exception("API Error"),
+        ]
+
+        with pytest.raises(Exception, match="API Error"):
+            await market_client.get_bars("MNQ", days=1)
+
+    def test_select_best_contract_exact_match(self, market_client):
+        """Test contract selection with exact symbol match."""
+        instruments = [
+            {"name": "MNQ", "description": "Base MNQ"},
+            {"name": "MNQU25", "description": "MNQ Sep 2025"},
+            {"name": "MNQZ25", "description": "MNQ Dec 2025"},
+        ]
+
+        result = market_client._select_best_contract(instruments, "MNQ")
+
+        assert result["name"] == "MNQ"
+
+    def test_select_best_contract_futures_front_month(self, market_client):
+        """Test contract selection among multiple futures expiries."""
+        instruments = [
+            {"name": "MNQZ25", "description": "MNQ Dec 2025"},
+            {"name": "MNQU25", "description": "MNQ Sep 2025"},
+            {"name": "MNQH26", "description": "MNQ Mar 2026"},
+        ]
+
+        result = market_client._select_best_contract(instruments, "MNQ")
+
+        # Sorted by contract name, MNQH26 comes first and is selected
+        assert result["name"] == "MNQH26"
+
+    def test_select_best_contract_no_instruments(self, market_client):
+        """Test contract selection with no instruments."""
+        with pytest.raises(ProjectXInstrumentError, match="No instruments found"):
+            market_client._select_best_contract([], "MNQ")
+
+    def test_select_best_contract_case_insensitive(self, market_client):
+        """Test contract selection is case insensitive."""
+        instruments = [
+            {"name": "mnq", "description": "Base MNQ"},
+            {"name": "MNQU25", "description": "MNQ Sep 2025"},
+        ]
+
+        result = market_client._select_best_contract(instruments, "MNQ")
+
+        assert result["name"] == "mnq"
+
+    @pytest.mark.asyncio
+    async def test_get_bars_different_intervals(self, market_client):
+        """Test bar data retrieval with different interval units."""
+        # Mock responses
+        instrument_response = {
+            "success": True,
+            "contracts": [
+                {
+                    "id": "CON.F.US.ES.U25",
+                    "name": "ESU25",
+                    "description": "E-mini S&P 500 futures",
+                    "tickSize": 0.25,
+                    "tickValue": 12.5,
+                    "activeContract": True,
+                    "symbolId": "ES",
+                }
+            ],
+        }
+
+        bars_response = {"success": True, "bars": []}
+
+        market_client._make_request.side_effect = [
+            instrument_response,
+            bars_response,
+            instrument_response,
+            bars_response,
+            instrument_response,
+            bars_response,
+        ]
+
+        # Test minute bars
+        bars = await market_client.get_bars("ES", days=1, interval=15, unit=2)
+        assert isinstance(bars, pl.DataFrame)
+
+        # Test hourly bars
+        bars = await market_client.get_bars("ES", days=7, interval=1, unit=3)
+        assert isinstance(bars, pl.DataFrame)
+
+        # Test daily bars
+        bars = await market_client.get_bars("ES", days=30, interval=1, unit=4)
+        assert isinstance(bars, pl.DataFrame)
+
+    @pytest.mark.asyncio
+    async def test_authentication_check(self, market_client):
+        """Test that methods ensure authentication."""
+        # Mock responses for each method
+        instrument_response = {"success": True, "contracts": []}
+
+        market_client._make_request.return_value = instrument_response
+
+        # Test get_instrument
+        try:
+            await market_client.get_instrument("MNQ")
+        except Exception:
+            pass
+        market_client._ensure_authenticated.assert_called()
+
+        # Reset mock
+        market_client._ensure_authenticated.reset_mock()
+
+        # Test search_instruments
+        await market_client.search_instruments("MNQ")
+        market_client._ensure_authenticated.assert_called()
+
+        # Reset mock
+        market_client._ensure_authenticated.reset_mock()
+
+        # Test get_bars (the call itself may fail; we only check that auth ran)
+        try:
+            await market_client.get_bars("MNQ")
+        except Exception:
+            pass
+        market_client._ensure_authenticated.assert_called()
diff --git a/tests/test_client_trading.py b/tests/test_client_trading.py
new file mode 100644
index 0000000..271ebcc
--- /dev/null
+++ b/tests/test_client_trading.py
@@ -0,0 +1,594 @@
+"""Comprehensive tests for the trading module of ProjectX client."""
+
+import datetime
+from datetime import timedelta
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+import pytz
+
+from project_x_py.client.trading import TradingMixin
+from project_x_py.exceptions import ProjectXError
+from project_x_py.models import Account, Position, Trade
+
+
+class MockTradingClient(TradingMixin):
+    """Mock client that includes TradingMixin for testing."""
+
+    def __init__(self):
+        super().__init__()
+        self.base_url = "https://api.test.com"
+        self._http_client = AsyncMock()
+        self._make_request = AsyncMock()
+        self._ensure_authenticated = AsyncMock()
+        self.account_info = None
+
+
+class TestTradingMixin:
+    """Test suite for TradingMixin class."""
+
+    @pytest.fixture
+    def trading_client(self):
+        """Create a mock client with TradingMixin for testing."""
+        return MockTradingClient()
+
+    @pytest.mark.asyncio
+    async def test_get_positions_deprecated(self, trading_client):
+        """Test that get_positions shows deprecation warning."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        mock_response = {
+            "success": True,
+            "positions": [
+                {
+                    "id": "pos1",
+                    "accountId": 12345,
+                    "contractId": "MNQ",
+                    "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(),
+                    "size": 2,
+                    "averagePrice": 15000.0,
+                    "type": 1,
+                }
+            ],
+        }
+        trading_client._make_request.return_value = mock_response
+
+        with pytest.warns(DeprecationWarning, match="get_positions.*deprecated"):
+            positions = await trading_client.get_positions()
+
+        assert len(positions) == 1
+        assert positions[0].contractId == "MNQ"
+
+    @pytest.mark.asyncio
+    async def test_search_open_positions_success(self, trading_client):
+        """Test successful position search."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        mock_response = {
+            "success": True,
+            "positions": [
+                {
+                    "id": "pos1",
+                    "accountId": 12345,
+                    "contractId": "ES",
+                    "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(),
+                    "size": 1,
+                    "averagePrice": 4500.0,
+                    "type":
1, + }, + { + "id": "pos2", + "accountId": 12345, + "contractId": "NQ", + "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(), + "size": 2, + "averagePrice": 15000.0, + "type": 2, # SHORT type + }, + ], + } + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions() + + assert len(positions) == 2 + assert positions[0].contractId == "ES" + assert positions[0].size == 1 + assert positions[0].type == 1 # LONG + assert positions[1].contractId == "NQ" + assert positions[1].size == 2 + assert positions[1].type == 2 # SHORT + trading_client._ensure_authenticated.assert_called_once() + + @pytest.mark.asyncio + async def test_search_open_positions_with_account_id(self, trading_client): + """Test position search with specific account ID.""" + # No account_info set, but provide explicit account_id + custom_account_id = 67890 + + mock_response = { + "success": True, + "positions": [ + { + "id": "pos1", + "accountId": custom_account_id, + "contractId": "MGC", + "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(), + "size": 10, + "averagePrice": 1900.0, + "type": 1, + } + ], + } + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions( + account_id=custom_account_id + ) + + assert len(positions) == 1 + assert positions[0].accountId == custom_account_id + + # Verify the request was made with the custom account ID + trading_client._make_request.assert_called_once_with( + "POST", "/Position/searchOpen", data={"accountId": custom_account_id} + ) + + @pytest.mark.asyncio + async def test_search_open_positions_no_account(self, trading_client): + """Test error when no account is available.""" + # No account_info and no account_id provided + trading_client.account_info = None + + with pytest.raises(ProjectXError, match="No account ID available"): + await trading_client.search_open_positions() + + @pytest.mark.asyncio + async def test_search_open_positions_list_response(self, trading_client): + """Test handling of list response format (new API).""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + # API returns list directly (new format) + mock_response = [ + { + "id": "pos1", + "accountId": 12345, + "contractId": "CL", + "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(), + "size": 5, + "averagePrice": 75.50, + "type": 1, + } + ] + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions() + + assert len(positions) == 1 + assert positions[0].contractId == "CL" + + @pytest.mark.asyncio + async def test_search_open_positions_empty(self, trading_client): + """Test empty position response.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + mock_response = {"success": True, "positions": []} + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions() + + assert positions == [] + + @pytest.mark.asyncio + async def test_search_open_positions_none_response(self, trading_client): + """Test None response handling.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + trading_client._make_request.return_value = None + + 
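# A None payload from the API should degrade to an empty list rather than raise. +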
positions = await trading_client.search_open_positions() + + assert positions == [] + + @pytest.mark.asyncio + async def test_search_open_positions_failed_response(self, trading_client): + """Test handling of failed API response.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + mock_response = {"success": False, "error": "API Error"} + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions() + + assert positions == [] + + @pytest.mark.asyncio + async def test_search_open_positions_invalid_response_type(self, trading_client): + """Test handling of invalid response type.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + # Invalid response type (string instead of list/dict) + trading_client._make_request.return_value = "invalid response" + + positions = await trading_client.search_open_positions() + + assert positions == [] + + @pytest.mark.asyncio + async def test_search_trades_success(self, trading_client): + """Test successful trade search.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + mock_response = [ + { + "id": 1, + "accountId": 12345, + "contractId": "ES", + "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(), + "price": 4500.0, + "profitAndLoss": 50.0, + "fees": 2.50, + "side": 0, # Buy + "size": 2, + "voided": False, + "orderId": 100, + }, + { + "id": 2, + "accountId": 12345, + "contractId": "NQ", + "creationTimestamp": ( + datetime.datetime.now(pytz.UTC) - timedelta(hours=1) + ).isoformat(), + "price": 15000.0, + "profitAndLoss": None, # Half-turn trade + "fees": 2.25, + "side": 1, # Sell + "size": 1, + "voided": False, + "orderId": 101, + }, + ] + trading_client._make_request.return_value = mock_response + + trades = await trading_client.search_trades() + + assert len(trades) == 2 + assert trades[0].contractId == "ES" + assert trades[0].size == 2 + assert trades[0].price == 4500.0 + assert trades[0].side == 0 # Buy + assert trades[1].contractId == "NQ" + assert trades[1].size == 1 + assert trades[1].side == 1 # Sell + trading_client._ensure_authenticated.assert_called_once() + + @pytest.mark.asyncio + async def test_search_trades_with_date_range(self, trading_client): + """Test trade search with custom date range.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + start_date = datetime.datetime(2025, 1, 1, 9, 30, tzinfo=pytz.UTC) + end_date = datetime.datetime(2025, 1, 15, 16, 0, tzinfo=pytz.UTC) + + mock_response = [] + trading_client._make_request.return_value = mock_response + + trades = await trading_client.search_trades( + start_date=start_date, + end_date=end_date, + ) + + # Verify the request parameters + trading_client._make_request.assert_called_once_with( + "GET", + "/trades/search", + params={ + "accountId": 12345, + "startDate": start_date.isoformat(), + "endDate": end_date.isoformat(), + "limit": 100, + }, + ) + + @pytest.mark.asyncio + async def test_search_trades_with_contract_filter(self, trading_client): + """Test trade search with contract ID filter.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + 
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        mock_response = [
+            {
+                "id": 1,
+                "accountId": 12345,
+                "contractId": "MNQ",
+                "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(),
+                "price": 15000.0,
+                "profitAndLoss": 75.0,
+                "fees": 2.25,
+                "side": 0,  # Buy
+                "size": 3,
+                "voided": False,
+                "orderId": 102,
+            }
+        ]
+        trading_client._make_request.return_value = mock_response
+
+        trades = await trading_client.search_trades(contract_id="MNQ", limit=50)
+
+        assert len(trades) == 1
+        assert trades[0].contractId == "MNQ"
+
+        # Verify contract_id was included in request
+        call_args = trading_client._make_request.call_args
+        assert call_args[1]["params"]["contractId"] == "MNQ"
+        assert call_args[1]["params"]["limit"] == 50
+
+    @pytest.mark.asyncio
+    async def test_search_trades_custom_account_id(self, trading_client):
+        """Test trade search with custom account ID."""
+        # No account_info, use explicit account_id
+        custom_account_id = 67890
+
+        mock_response = []
+        trading_client._make_request.return_value = mock_response
+
+        trades = await trading_client.search_trades(account_id=custom_account_id)
+
+        # Verify the request used custom account ID
+        call_args = trading_client._make_request.call_args
+        assert call_args[1]["params"]["accountId"] == custom_account_id
+
+    @pytest.mark.asyncio
+    async def test_search_trades_no_account(self, trading_client):
+        """Test error when no account is available for trade search."""
+        trading_client.account_info = None
+
+        with pytest.raises(ProjectXError, match="No account information available"):
+            await trading_client.search_trades()
+
+    @pytest.mark.asyncio
+    async def test_search_trades_default_dates(self, trading_client):
+        """Test default date range (30 days)."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        mock_response = []
+        trading_client._make_request.return_value = mock_response
+
+        # Mock datetime.now to get consistent test results
+        mock_now = datetime.datetime(2025, 1, 15, 12, 0, tzinfo=pytz.UTC)
+        with patch("project_x_py.client.trading.datetime") as mock_datetime:
+            mock_datetime.datetime.now.return_value = mock_now
+            mock_datetime.timedelta = timedelta
+
+            trades = await trading_client.search_trades()
+
+        # Verify date range is approximately 30 days
+        call_args = trading_client._make_request.call_args
+        params = call_args[1]["params"]
+
+        start_date = datetime.datetime.fromisoformat(params["startDate"])
+        end_date = datetime.datetime.fromisoformat(params["endDate"])
+
+        date_diff = end_date - start_date
+        assert 29 <= date_diff.days <= 31
+
+    @pytest.mark.asyncio
+    async def test_search_trades_with_start_date_only(self, trading_client):
+        """Test trade search with only start date provided."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        start_date = datetime.datetime(2025, 1, 1, 9, 30, tzinfo=pytz.UTC)
+
+        mock_response = []
+        trading_client._make_request.return_value = mock_response
+
+        # Mock datetime.now for end_date default
+        mock_now = datetime.datetime(2025, 1, 15, 12, 0, tzinfo=pytz.UTC)
+        with patch("project_x_py.client.trading.datetime") as mock_datetime:
+            mock_datetime.datetime.now.return_value = mock_now
+            mock_datetime.timedelta = timedelta
+
+            trades = await trading_client.search_trades(start_date=start_date)
+
+        # Verify end_date defaulted to now
+        call_args = trading_client._make_request.call_args
+        params = call_args[1]["params"]
+
+        assert params["startDate"] == start_date.isoformat()
+        end_date = datetime.datetime.fromisoformat(params["endDate"])
+        assert end_date == mock_now
+
+    @pytest.mark.asyncio
+    async def test_search_trades_with_end_date_only(self, trading_client):
+        """Test trade search with only end date provided."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        end_date = datetime.datetime(2025, 1, 15, 16, 0, tzinfo=pytz.UTC)
+
+        mock_response = []
+        trading_client._make_request.return_value = mock_response
+
+        trades = await trading_client.search_trades(end_date=end_date)
+
+        # Verify start_date is 30 days before end_date
+        call_args = trading_client._make_request.call_args
+        params = call_args[1]["params"]
+
+        start_date = datetime.datetime.fromisoformat(params["startDate"])
+        assert params["endDate"] == end_date.isoformat()
+
+        date_diff = end_date - start_date
+        assert 29 <= date_diff.days <= 31
+
+    @pytest.mark.asyncio
+    async def test_search_trades_empty_response(self, trading_client):
+        """Test handling of empty trade response."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        trading_client._make_request.return_value = []
+
+        trades = await trading_client.search_trades()
+
+        assert trades == []
+
+    @pytest.mark.asyncio
+    async def test_search_trades_none_response(self, trading_client):
+        """Test handling of None response."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        trading_client._make_request.return_value = None
+
+        trades = await trading_client.search_trades()
+
+        assert trades == []
+
+    @pytest.mark.asyncio
+    async def test_search_trades_invalid_response_type(self, trading_client):
+        """Test handling of invalid response type."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        # Invalid response type (dict instead of list)
+        trading_client._make_request.return_value = {"trades": []}
+
+        trades = await trading_client.search_trades()
+
+        assert trades == []
+
+    @pytest.mark.asyncio
+    async def test_authentication_called_for_all_methods(self, trading_client):
+        """Test that all methods ensure authentication."""
+        trading_client.account_info = Account(
+            id=12345,
+            name="Test Account",
+            balance=10000.0,
+            canTrade=True,
+            isVisible=True,
+            simulated=False,
+        )
+
+        trading_client._make_request.return_value = []
+
+        # Test search_open_positions
+        await trading_client.search_open_positions()
+        assert trading_client._ensure_authenticated.call_count == 1
+
+        # Test search_trades
+        await trading_client.search_trades()
+        assert trading_client._ensure_authenticated.call_count == 2
+
+        # Test get_positions (deprecated)
+        with pytest.warns(DeprecationWarning):
+            await trading_client.get_positions()
+        assert trading_client._ensure_authenticated.call_count == 3
diff --git a/uv.lock b/uv.lock
index abcdb34..4a87783 100644
--- a/uv.lock
+++ b/uv.lock
@@ -309,6 +309,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/18/d8/9b768ac73a8ac2d10c080af23937212434a958c8d2a1c84e89b450237942/coverage-7.10.2-py3-none-any.whl", hash = "sha256:95db3750dd2e6e93d99fa2498f3a1580581e49c494bddccc6f85c5c21604921f", size = 206973 },
 ]
 
+[[package]]
+name = "deprecated"
+version = "1.2.18"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "wrapt" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/97/06afe62762c9a8a86af0cfb7bfdab22a43ad17138b07af5b1a58442690a2/deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d", size = 2928744 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998 },
+]
+
 [[package]]
 name = "distlib"
 version = "0.4.0"
@@ -965,10 +977,11 @@ wheels = [
 
 [[package]]
 name = "project-x-py"
-version = "3.1.11"
+version = "3.1.12"
 source = { editable = "." }
 dependencies = [
     { name = "cachetools" },
+    { name = "deprecated" },
     { name = "httpx", extra = ["http2"] },
     { name = "lz4" },
     { name = "msgpack-python" },
@@ -1043,6 +1056,7 @@ dev = [
     { name = "mypy" },
     { name = "pre-commit" },
     { name = "psutil" },
+    { name = "pyjwt" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
     { name = "pytest-cov" },
@@ -1064,6 +1078,7 @@ requires-dist = [
     { name = "aioresponses", marker = "extra == 'test'", specifier = ">=0.7.6" },
     { name = "black", marker = "extra == 'dev'", specifier = ">=23.0.0" },
     { name = "cachetools", specifier = ">=6.1.0" },
+    { name = "deprecated", specifier = ">=1.2.18" },
     { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" },
     { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" },
     { name = "lz4", specifier = ">=4.4.4" },
@@ -1108,6 +1123,7 @@ dev = [
     { name = "mypy", specifier = ">=1.0.0" },
     { name = "pre-commit", specifier = ">=4.2.0" },
     { name = "psutil", specifier = ">=7.0.0" },
+    { name = "pyjwt", specifier = ">=2.10.1" },
     { name = "pytest", specifier = ">=7.0.0" },
     { name = "pytest-asyncio", specifier = ">=1.1.0" },
     { name = "pytest-cov", specifier = ">=4.0.0" },
@@ -1261,6 +1277,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217 },
 ]
 
+[[package]]
+name = "pyjwt"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 },
+]
+
 [[package]]
 name = "pytest"
 version = "8.4.1"
@@ -1679,6 +1704,55 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ba/d1/501076b54481412df1bc4cdd1fe479f66e17857c63ec5981bedcdc2ca793/websocket_client-1.0.0-py2.py3-none-any.whl", hash = "sha256:57f876f1af4731cacb806cf54d02f5fbf75dee796053b9a5b94fd7c1d9621db9", size = 68319 },
 ]
 
+[[package]]
+name = "wrapt"
+version = "1.17.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998 },
+    { url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020 },
+    { url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098 },
+    { url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036 },
+    { url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156 },
+    { url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102 },
+    { url = "https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732 },
+    { url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705 },
+    { url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", size = 38877 },
+    { url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885 },
+    { url = "https://files.pythonhosted.org/packages/fc/f6/759ece88472157acb55fc195e5b116e06730f1b651b5b314c66291729193/wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0", size = 54003 },
+    { url = "https://files.pythonhosted.org/packages/4f/a9/49940b9dc6d47027dc850c116d79b4155f15c08547d04db0f07121499347/wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77", size = 39025 },
+    { url = "https://files.pythonhosted.org/packages/45/35/6a08de0f2c96dcdd7fe464d7420ddb9a7655a6561150e5fc4da9356aeaab/wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7", size = 39108 },
+    { url = "https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277", size = 88072 },
+    { url = "https://files.pythonhosted.org/packages/78/f2/efe19ada4a38e4e15b6dff39c3e3f3f73f5decf901f66e6f72fe79623a06/wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d", size = 88214 },
+    { url = "https://files.pythonhosted.org/packages/40/90/ca86701e9de1622b16e09689fc24b76f69b06bb0150990f6f4e8b0eeb576/wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa", size = 87105 },
+    { url = "https://files.pythonhosted.org/packages/fd/e0/d10bd257c9a3e15cbf5523025252cc14d77468e8ed644aafb2d6f54cb95d/wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050", size = 87766 },
+    { url = "https://files.pythonhosted.org/packages/e8/cf/7d848740203c7b4b27eb55dbfede11aca974a51c3d894f6cc4b865f42f58/wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8", size = 36711 },
+    { url = "https://files.pythonhosted.org/packages/57/54/35a84d0a4d23ea675994104e667ceff49227ce473ba6a59ba2c84f250b74/wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb", size = 38885 },
+    { url = "https://files.pythonhosted.org/packages/01/77/66e54407c59d7b02a3c4e0af3783168fff8e5d61def52cda8728439d86bc/wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16", size = 36896 },
+    { url = "https://files.pythonhosted.org/packages/02/a2/cd864b2a14f20d14f4c496fab97802001560f9f41554eef6df201cd7f76c/wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39", size = 54132 },
+    { url = "https://files.pythonhosted.org/packages/d5/46/d011725b0c89e853dc44cceb738a307cde5d240d023d6d40a82d1b4e1182/wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235", size = 39091 },
+    { url = "https://files.pythonhosted.org/packages/2e/9e/3ad852d77c35aae7ddebdbc3b6d35ec8013af7d7dddad0ad911f3d891dae/wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c", size = 39172 },
+    { url = "https://files.pythonhosted.org/packages/c3/f7/c983d2762bcce2326c317c26a6a1e7016f7eb039c27cdf5c4e30f4160f31/wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b", size = 87163 },
+    { url = "https://files.pythonhosted.org/packages/e4/0f/f673f75d489c7f22d17fe0193e84b41540d962f75fce579cf6873167c29b/wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa", size = 87963 },
+    { url = "https://files.pythonhosted.org/packages/df/61/515ad6caca68995da2fac7a6af97faab8f78ebe3bf4f761e1b77efbc47b5/wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7", size = 86945 },
+    { url = "https://files.pythonhosted.org/packages/d3/bd/4e70162ce398462a467bc09e768bee112f1412e563620adc353de9055d33/wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4", size = 86857 },
+    { url = "https://files.pythonhosted.org/packages/2b/b8/da8560695e9284810b8d3df8a19396a6e40e7518059584a1a394a2b35e0a/wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10", size = 37178 },
+    { url = "https://files.pythonhosted.org/packages/db/c8/b71eeb192c440d67a5a0449aaee2310a1a1e8eca41676046f99ed2487e9f/wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6", size = 39310 },
+    { url = "https://files.pythonhosted.org/packages/45/20/2cda20fd4865fa40f86f6c46ed37a2a8356a7a2fde0773269311f2af56c7/wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58", size = 37266 },
+    { url = "https://files.pythonhosted.org/packages/77/ed/dd5cf21aec36c80443c6f900449260b80e2a65cf963668eaef3b9accce36/wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a", size = 56544 },
+    { url = "https://files.pythonhosted.org/packages/8d/96/450c651cc753877ad100c7949ab4d2e2ecc4d97157e00fa8f45df682456a/wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067", size = 40283 },
+    { url = "https://files.pythonhosted.org/packages/d1/86/2fcad95994d9b572db57632acb6f900695a648c3e063f2cd344b3f5c5a37/wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454", size = 40366 },
+    { url = "https://files.pythonhosted.org/packages/64/0e/f4472f2fdde2d4617975144311f8800ef73677a159be7fe61fa50997d6c0/wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e", size = 108571 },
+    { url = "https://files.pythonhosted.org/packages/cc/01/9b85a99996b0a97c8a17484684f206cbb6ba73c1ce6890ac668bcf3838fb/wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f", size = 113094 },
+    { url = "https://files.pythonhosted.org/packages/25/02/78926c1efddcc7b3aa0bc3d6b33a822f7d898059f7cd9ace8c8318e559ef/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056", size = 110659 },
+    { url = "https://files.pythonhosted.org/packages/dc/ee/c414501ad518ac3e6fe184753632fe5e5ecacdcf0effc23f31c1e4f7bfcf/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804", size = 106946 },
+    { url = "https://files.pythonhosted.org/packages/be/44/a1bd64b723d13bb151d6cc91b986146a1952385e0392a78567e12149c7b4/wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977", size = 38717 },
+    { url = "https://files.pythonhosted.org/packages/79/d9/7cfd5a312760ac4dd8bf0184a6ee9e43c33e47f3dadc303032ce012b8fa3/wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116", size = 41334 },
+    { url = "https://files.pythonhosted.org/packages/46/78/10ad9781128ed2f99dbc474f43283b13fea8ba58723e98844367531c18e9/wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6", size = 38471 },
+    { url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591 },
+]
+
 [[package]]
 name = "yarl"
 version = "1.20.1"