diff --git a/.secrets.baseline b/.secrets.baseline index df51152..3e58244 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "CHANGELOG.md", "hashed_secret": "89a6cfe2a229151e8055abee107d45ed087bbb4f", "is_verified": false, - "line_number": 2149 + "line_number": 2183 } ], "README.md": [ @@ -325,5 +325,5 @@ } ] }, - "generated_at": "2025-09-01T01:31:44Z" + "generated_at": "2025-09-02T01:57:55Z" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 006ff20..70af97a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Migration guides will be provided for all breaking changes - Semantic versioning (MAJOR.MINOR.PATCH) is strictly followed +## [3.5.6] - 2025-02-02 + +### šŸ› Fixed + +**Multi-Instrument Event System**: +- **Event Forwarding**: Implemented event forwarding from instrument-specific EventBuses to suite-level EventBus +- **InstrumentContext Methods**: Added `on()`, `once()`, `off()`, and `wait_for()` methods that delegate to event_bus +- **Event Propagation**: Fixed broken event system that prevented `mnq_context.wait_for(EventType.NEW_BAR)` from working +- **Multi-Instrument Support**: Events now properly flow from individual instruments to the suite level + +**Bracket Order Improvements**: +- **Automatic Price Alignment**: Changed validation from failing to auto-aligning prices to tick size +- **Smart Adjustment**: Orders with misaligned prices are now automatically corrected instead of rejected +- **Better UX**: Improved user experience by handling price alignment transparently + +**Example Scripts**: +- **Advanced Trading Examples**: Fixed all 4 advanced trading examples with proper async/await patterns +- **Real-time Streaming**: Fixed bar data access in real-time streaming example +- **OrderBook Methods**: Corrected API usage with proper method names and parameters +- **TypedDict Access**: Fixed bracket notation for TypedDict field access + +### āœ… Testing + +- **Test Suite Updates**: Fixed 30 failing tests to match new correct behavior +- **Event System Tests**: Updated tests to verify event forwarding functionality +- **Price Alignment Tests**: Tests now verify automatic alignment instead of rejection +- **InstrumentContext Tests**: Added event_bus parameter to all test constructors + +### šŸ”§ Changed + +- **OrderManager**: Removed duplicate price alignment calls in `place_order()` method +- **TradingSuite**: Added `_setup_event_forwarding()` method for event bus connectivity +- **InstrumentContext**: Now requires `event_bus` parameter in constructor + ## [3.5.5] - 2025-01-21 ### āœ… Testing diff --git a/README.md b/README.md index 5f9a67e..e1fb15e 100644 --- a/README.md +++ b/README.md @@ -28,18 +28,18 @@ A **high-performance async Python SDK** for the [ProjectX Trading Platform](http This Python SDK acts as a bridge between your trading strategies and the ProjectX platform, handling all the complex API interactions, data processing, and real-time connectivity. -## šŸš€ v3.5.5 - Sessions Module Testing & Documentation +## šŸš€ v3.5.6 - Event System & Bracket Order Enhancements -**Latest Version**: v3.5.5 - Comprehensive testing and documentation improvements for the ETH vs RTH Trading Sessions feature, ensuring production-ready session filtering and analysis. +**Latest Version**: v3.5.6 - Critical fixes for multi-instrument event handling and automatic price alignment for bracket orders, ensuring robust real-time trading operations. -**Key Benefits**: -- šŸŽÆ **Multi-Asset Strategies**: Trade ES vs NQ pairs, commodity spreads, sector rotation -- šŸ“Š **Portfolio Management**: Unified risk management across multiple instruments -- šŸ”„ **Parallel Processing**: Efficient concurrent data processing and order management -- šŸ›”ļø **Backward Compatible**: Existing single-instrument code continues to work -- ⚔ **Performance Optimized**: Parallel context creation and resource sharing +**Key Improvements**: +- šŸ”„ **Event Forwarding**: Fixed multi-instrument event propagation with proper bus forwarding +- šŸŽÆ **Smart Price Alignment**: Bracket orders now auto-align to tick sizes instead of failing +- šŸ“Š **Enhanced Examples**: All advanced trading examples updated and tested +- šŸ›”ļø **Improved Reliability**: 30+ test fixes ensuring production stability +- ⚔ **Real-time Fixes**: Corrected bar data access in streaming examples -See [CHANGELOG.md](CHANGELOG.md) for complete v3.5.5 features including sessions module improvements and comprehensive example scripts. +See [CHANGELOG.md](CHANGELOG.md) for complete v3.5.6 fixes and previous version features. ### šŸ“¦ Production Stability Guarantee diff --git a/docs/examples/advanced.md b/docs/examples/advanced.md index d386f50..db7cd5f 100644 --- a/docs/examples/advanced.md +++ b/docs/examples/advanced.md @@ -17,175 +17,343 @@ Bracket orders combine entry, stop loss, and take profit in a single operation: ```python #!/usr/bin/env python """ -Advanced bracket order strategy with dynamic stops based on ATR +Advanced bracket order strategy with dynamic stops based on ATR. + +This example demonstrates: +- ATR-based dynamic stop loss and take profit levels +- RSI and SMA-based entry signals +- Bracket orders with automatic price alignment +- Real-time order monitoring and management +- Event-driven trade execution """ + import asyncio from decimal import Decimal -from project_x_py import TradingSuite, EventType +from typing import Optional + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event from project_x_py.indicators import ATR, RSI, SMA +from project_x_py.models import BracketOrderResponse + class ATRBracketStrategy: + """Advanced bracket order strategy using ATR for dynamic stops.""" + def __init__(self, suite: TradingSuite): self.suite = suite self.atr_period = 14 self.rsi_period = 14 self.sma_period = 20 self.position_size = 1 - self.active_orders = [] + self.active_orders: list[BracketOrderResponse] = [] + self.max_positions = 1 # Limit concurrent positions - async def calculate_dynamic_levels(self): + async def calculate_dynamic_levels(self) -> tuple[float, float]: """Calculate stop and target levels based on ATR.""" - bars = await self.suite.data.get_data("5min") + try: + # Get bars for ATR calculation + bars = await self.suite["MNQ"].data.get_data("5min") - # Calculate ATR for volatility-based stops - atr_values = bars.pipe(ATR, period=self.atr_period) - current_atr = float(atr_values[-1]) + if bars is None or bars.is_empty(): + print("No data available for ATR calculation") + # Return default values + return 50.0, 100.0 - # Dynamic stop loss: 2x ATR - stop_offset = Decimal(str(current_atr * 2)) + # Calculate ATR for volatility-based stops + with_atr = bars.pipe(ATR, period=self.atr_period) - # Dynamic take profit: 3x ATR (1.5:1 reward:risk) - target_offset = Decimal(str(current_atr * 3)) + # Get current ATR value + atr_column = f"atr_{self.atr_period}" + if atr_column not in with_atr.columns: + print(f"ATR column {atr_column} not found") + return 50.0, 100.0 - return stop_offset, target_offset + current_atr = float(with_atr[atr_column].tail(1)[0]) - async def check_entry_conditions(self): + # Dynamic stop loss: 2x ATR + stop_offset = current_atr * 2 + + # Dynamic take profit: 3x ATR (1.5:1 reward:risk) + target_offset = current_atr * 3 + + return stop_offset, target_offset + + except Exception as e: + print(f"Error calculating ATR levels: {e}") + # Return default values on error + return 50.0, 100.0 + + async def check_entry_conditions(self) -> tuple[Optional[str], Optional[float]]: """Check if conditions are met for entry.""" - bars = await self.suite.data.get_data("5min") + try: + bars = await self.suite["MNQ"].data.get_data("5min") - if len(bars) < max(self.rsi_period, self.sma_period, self.atr_period): - return None, None + if bars is None or bars.is_empty(): + return None, None + + # Ensure we have enough data for indicators + if len(bars) < max(self.rsi_period, self.sma_period, self.atr_period): + return None, None + + # Calculate indicators using pipe method + with_rsi = bars.pipe(RSI, period=self.rsi_period) + with_sma = with_rsi.pipe(SMA, period=self.sma_period) + + # Get current values from the last row + last_row = with_sma.tail(1) - # Calculate indicators - rsi = bars.pipe(RSI, period=self.rsi_period) - sma = bars.pipe(SMA, period=self.sma_period) - current_price = bars['close'][-1] - current_rsi = rsi[-1] - current_sma = sma[-1] + current_price = float(last_row["close"][0]) - # Long signal: Price above SMA and RSI oversold recovery - if current_price > current_sma and 30 < current_rsi < 50: - return "long", current_price + # Get RSI value + rsi_column = f"rsi_{self.rsi_period}" + current_rsi = ( + float(last_row[rsi_column][0]) + if rsi_column in last_row.columns + else 50.0 + ) + + # Get SMA value + sma_column = f"sma_{self.sma_period}" + current_sma = ( + float(last_row[sma_column][0]) + if sma_column in last_row.columns + else current_price + ) - # Short signal: Price below SMA and RSI overbought decline - elif current_price < current_sma and 50 < current_rsi < 70: - return "short", current_price + # Long signal: Price above SMA and RSI oversold recovery + if current_price > current_sma and 30 < current_rsi < 50: + return "long", current_price - return None, None + # Short signal: Price below SMA and RSI overbought decline + elif current_price < current_sma and 50 < current_rsi < 70: + return "short", current_price - async def place_bracket_order(self, direction: str): + return None, None + + except Exception as e: + print(f"Error checking entry conditions: {e}") + return None, None + + async def place_bracket_order( + self, direction: str + ) -> Optional[BracketOrderResponse]: """Place a bracket order based on strategy conditions.""" try: - # Calculate dynamic stop and target levels - stop_offset, target_offset = await self.calculate_dynamic_levels() + # Get current price + current_price = await self.suite["MNQ"].data.get_current_price() + if not current_price: + print("Could not get current price") + return None - # Determine side (0=Buy, 1=Sell) - side = 0 if direction == "long" else 1 + # Calculate dynamic stop and target levels (offsets) + stop_offset, target_offset = await self.calculate_dynamic_levels() - print(f"Placing {direction.upper()} bracket order:") - print(f" Size: {self.position_size} contracts") - print(f" Stop Loss: {stop_offset} points") - print(f" Take Profit: {target_offset} points") + # Calculate actual price levels + if direction == "long": + stop_loss_price = current_price - stop_offset + take_profit_price = current_price + target_offset + side = 0 # Buy + else: # short + stop_loss_price = current_price + stop_offset + take_profit_price = current_price - target_offset + side = 1 # Sell - # Place bracket order - result = await self.suite.orders.place_bracket_order( - contract_id=self.suite.instrument_info.id, + # Display trade setup + print("\n" + "=" * 60) + print(f"{direction.upper()} BRACKET ORDER SETUP") + print("=" * 60) + print(f"Current Price: ${current_price:.2f}") + print(f"Position Size: {self.position_size} contracts") + print(f"Stop Loss: ${stop_loss_price:.2f} ({stop_offset:.2f} points)") + print(f"Take Profit: ${take_profit_price:.2f} ({target_offset:.2f} points)") + + # Calculate risk/reward + risk = abs(current_price - stop_loss_price) + reward = abs(take_profit_price - current_price) + rr_ratio = reward / risk if risk > 0 else 0 + print(f"Risk/Reward Ratio: {rr_ratio:.2f}:1") + print("=" * 60) + + # Get instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + print("\nPlacing bracket order...") + + # Place bracket order with market entry + # Prices will be automatically aligned to tick size + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, side=side, size=self.position_size, - stop_offset=stop_offset, - target_offset=target_offset + entry_price=None, # Market order + entry_type="market", + stop_loss_price=stop_loss_price, + take_profit_price=take_profit_price, ) - print(f"Bracket order placed successfully:") - print(f" Main Order ID: {result.main_order_id}") - print(f" Stop Order ID: {result.stop_order_id}") - print(f" Target Order ID: {result.target_order_id}") + if result and result.success: + print("\nāœ… Bracket order placed successfully!") + print(f" Entry Order ID: {result.entry_order_id}") + print(f" Stop Order ID: {result.stop_order_id}") + print(f" Target Order ID: {result.target_order_id}") - self.active_orders.append(result) - return result + self.active_orders.append(result) + return result + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place bracket order: {error_msg}") + return None except Exception as e: print(f"Failed to place bracket order: {e}") + import traceback + + traceback.print_exc() return None async def monitor_orders(self): """Monitor active orders and handle fills/cancellations.""" - for bracket in self.active_orders[:]: # Copy list to modify during iteration + if not self.active_orders: + return + + # Copy list to allow modification during iteration + for bracket in self.active_orders[:]: try: - # Check main order status - main_status = await self.suite.orders.get_order_status(bracket.main_order_id) + if bracket is None: + continue - if main_status.status == "Filled": - print(f"Main order {bracket.main_order_id} filled at ${main_status.fill_price}") + # For this example, we'll just track the count + # In production, you would check order status via the API - elif main_status.status in ["Cancelled", "Rejected"]: - print(f"Main order {bracket.main_order_id} {main_status.status}") - self.active_orders.remove(bracket) + # Note: The actual order monitoring would typically be done + # through event handlers rather than polling except Exception as e: - print(f"Error monitoring order {bracket.main_order_id}: {e}") + print(f"Error monitoring orders: {e}") + + def remove_completed_order(self, order_id: int): + """Remove a completed order from tracking.""" + self.active_orders = [ + bracket + for bracket in self.active_orders + if bracket and bracket.entry_order_id != order_id + ] + async def main(): + """Main function to run the ATR bracket strategy.""" + print("Initializing Advanced Bracket Order Strategy...") + # Create trading suite with required timeframes suite = await TradingSuite.create( ["MNQ"], timeframes=["1min", "5min"], initial_days=10, # Need historical data for indicators - features=["risk_manager"] + features=["risk_manager"], ) # Initialize strategy strategy = ATRBracketStrategy(suite) + mnq_context = suite["MNQ"] + + # Track last bar time to avoid duplicate processing + last_bar_time = {} # Set up event handlers for real-time monitoring - async def on_new_bar(event): - if event.data.get('timeframe') == '5min': - print(f"New 5min bar: ${event.data['close']:.2f}") + async def on_new_bar(event: Event): + """Handle new bar events.""" + timeframe = event.data.get("timeframe", "unknown") + + if timeframe == "5min": + # Avoid duplicate processing + current_time = event.data.get("timestamp", "") + if current_time == last_bar_time.get(timeframe): + return + last_bar_time[timeframe] = current_time + + # Get bar data + bar_data = event.data.get("data", {}) + close_price = bar_data.get("close", 0) + + if close_price: + print(f"\nNew 5min bar: ${close_price:.2f}") + + # Check if we can take a new position + if len(strategy.active_orders) >= strategy.max_positions: + return # Check for entry signals direction, price = await strategy.check_entry_conditions() - if direction and len(strategy.active_orders) == 0: # No active positions - print(f"Entry signal detected: {direction.upper()} at ${price:.2f}") + if direction: + print( + f"\nšŸŽÆ Entry signal detected: {direction.upper()} at ${price:.2f}" + ) - # Confirm with user before placing order - response = input(f"Place {direction.upper()} bracket order? (y/N): ") - if response.lower().startswith('y'): + # Auto-confirm for demo, or ask user + if False: # Set to True for auto-trading await strategy.place_bracket_order(direction) - - async def on_order_filled(event): - order_data = event.data - print(f"ORDER FILLED: {order_data.get('order_id')} at ${order_data.get('fill_price', 0):.2f}") + else: + # Confirm with user before placing order + response = input( + f"Place {direction.upper()} bracket order? (y/N): " + ) + if response.lower().startswith("y"): + await strategy.place_bracket_order(direction) # Register event handlers - await suite["MNQ"].on(EventType.NEW_BAR, on_new_bar) - await suite["MNQ"].on(EventType.ORDER_FILLED, on_order_filled) + await mnq_context.on(EventType.NEW_BAR, on_new_bar) - print("Advanced Bracket Order Strategy Active") - print("Monitoring for entry signals on 5-minute bars...") + print("\n" + "=" * 60) + print("ADVANCED BRACKET ORDER STRATEGY ACTIVE") + print("=" * 60) + print("Strategy Settings:") + print(f" ATR Period: {strategy.atr_period}") + print(f" RSI Period: {strategy.rsi_period}") + print(f" SMA Period: {strategy.sma_period}") + print(f" Position Size: {strategy.position_size} contracts") + print(f" Max Positions: {strategy.max_positions}") + print("\nMonitoring for entry signals on 5-minute bars...") print("Press Ctrl+C to exit") + print("=" * 60) try: while True: - await asyncio.sleep(5) + await asyncio.sleep(30) # Status update every 30 seconds # Monitor active orders await strategy.monitor_orders() # Display current market info - current_price = await suite["MNQ"].data.get_current_price() - active_count = len(strategy.active_orders) - print(f"Price: ${current_price:.2f} | Active Orders: {active_count}") + current_price = await mnq_context.data.get_current_price() + if current_price: + active_count = len(strategy.active_orders) + print( + f"\nStatus: Price=${current_price:.2f} | Active Orders={active_count}" + ) except KeyboardInterrupt: - print("\nShutting down strategy...") + print("\n\nShutting down strategy...") # Cancel any remaining orders for bracket in strategy.active_orders: - try: - await suite["MNQ"].orders.cancel_order(bracket.main_order_id) - print(f"Cancelled order {bracket.main_order_id}") - except Exception as e: - print(f"Error cancelling order: {e}") + if bracket: + try: + # Cancel stop and target orders + if bracket.stop_order_id: + await mnq_context.orders.cancel_order(bracket.stop_order_id) + print(f"Cancelled stop order {bracket.stop_order_id}") + if bracket.target_order_id: + await mnq_context.orders.cancel_order(bracket.target_order_id) + print(f"Cancelled target order {bracket.target_order_id}") + except Exception as e: + print(f"Error cancelling orders: {e}") + + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("Strategy disconnected. Goodbye!") if __name__ == "__main__": @@ -199,231 +367,415 @@ Advanced strategy using multiple timeframes for confirmation: ```python #!/usr/bin/env python """ -Multi-timeframe momentum strategy with confluence analysis +Multi-timeframe momentum strategy with confluence analysis. + +This example demonstrates: +- Multi-timeframe analysis (5min, 15min, 1hr) +- Momentum and trend confluence detection +- Technical indicators (RSI, MACD, EMA, ATR) +- Dynamic position sizing based on ATR +- Bracket orders with volatility-based stops """ + import asyncio from decimal import Decimal -from project_x_py import TradingSuite, EventType -from project_x_py.indicators import RSI, MACD, EMA, ATR +from typing import Any, Optional + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event +from project_x_py.indicators import ATR, EMA, MACD, RSI +from project_x_py.models import BracketOrderResponse + class MultiTimeframeMomentumStrategy: + """Multi-timeframe momentum trading strategy.""" + def __init__(self, suite: TradingSuite): self.suite = suite self.position_size = 1 self.risk_per_trade = Decimal("0.02") # 2% risk per trade - self.active_position = None + self.account_balance = Decimal("50000") # Default balance + self.active_position: Optional[dict[str, Any]] = None - async def analyze_timeframe(self, timeframe: str): + async def analyze_timeframe(self, timeframe: str) -> Optional[dict[str, Any]]: """Analyze a specific timeframe for momentum signals.""" - bars = await self.suite.data.get_data(timeframe) + try: + # Get bars for the timeframe + bars = await self.suite["MNQ"].data.get_data(timeframe) + + if bars is None or bars.is_empty(): + print(f"No data available for {timeframe}") + return None + + if len(bars) < 50: # Need sufficient data for indicators + print(f"Insufficient data for {timeframe} (need 50+ bars)") + return None - if len(bars) < 50: # Need sufficient data + # Calculate indicators using pipe method + with_rsi = bars.pipe(RSI, period=14) + with_macd = with_rsi.pipe( + MACD, fast_period=12, slow_period=26, signal_period=9 + ) + with_ema20 = with_macd.pipe(EMA, period=20) + with_ema50 = with_ema20.pipe(EMA, period=50) + + # Get the last row for current values + last_row = with_ema50.tail(1) + + # Extract values from the last row + current_price = float(last_row["close"][0]) + current_rsi = ( + float(last_row["rsi_14"][0]) if "rsi_14" in last_row.columns else 50.0 + ) + + # MACD values + current_macd = ( + float(last_row["macd"][0]) if "macd" in last_row.columns else 0.0 + ) + macd_signal = ( + float(last_row["macd_signal"][0]) + if "macd_signal" in last_row.columns + else 0.0 + ) + + # EMA values + current_ema_20 = ( + float(last_row["ema_20"][0]) + if "ema_20" in last_row.columns + else current_price + ) + current_ema_50 = ( + float(last_row["ema_50"][0]) + if "ema_50" in last_row.columns + else current_price + ) + + # Determine trend and momentum + trend = "bullish" if current_ema_20 > current_ema_50 else "bearish" + momentum = "positive" if current_macd > macd_signal else "negative" + rsi_level = ( + "oversold" + if current_rsi < 30 + else "overbought" + if current_rsi > 70 + else "neutral" + ) + + return { + "timeframe": timeframe, + "price": current_price, + "trend": trend, + "momentum": momentum, + "rsi_level": rsi_level, + "rsi": current_rsi, + "macd": current_macd, + "macd_signal": macd_signal, + "ema_20": current_ema_20, + "ema_50": current_ema_50, + } + + except Exception as e: + print(f"Error analyzing {timeframe}: {e}") return None - # Calculate indicators - rsi = bars.pipe(RSI, period=14) - macd_result = bars.pipe(MACD, fast_period=12, slow_period=26, signal_period=9) - ema_20 = bars.pipe(EMA, period=20) - ema_50 = bars.pipe(EMA, period=50) - - current_price = bars['close'][-1] - current_rsi = rsi[-1] - current_macd = macd_result['macd'][-1] - macd_signal = macd_result['signal'][-1] - current_ema_20 = ema_20[-1] - current_ema_50 = ema_50[-1] - - # Determine trend and momentum - trend = "bullish" if current_ema_20 > current_ema_50 else "bearish" - momentum = "positive" if current_macd > macd_signal else "negative" - rsi_level = "oversold" if current_rsi < 30 else "overbought" if current_rsi > 70 else "neutral" - - return { - "timeframe": timeframe, - "price": current_price, - "trend": trend, - "momentum": momentum, - "rsi_level": rsi_level, - "rsi": current_rsi, - "macd": current_macd, - "macd_signal": macd_signal, - "ema_20": current_ema_20, - "ema_50": current_ema_50 - } - - async def check_confluence(self): + async def check_confluence(self) -> tuple[Optional[str], Optional[list]]: """Check for confluence across multiple timeframes.""" - # Analyze all timeframes - tf_5min = await self.analyze_timeframe("5min") - tf_15min = await self.analyze_timeframe("15min") # Add 15min if available - tf_1hr = await self.analyze_timeframe("1hr") # Add 1hr if available + # Analyze all configured timeframes + analyses = [] - analyses = [tf for tf in [tf_5min, tf_15min, tf_1hr] if tf is not None] + for timeframe in ["5min", "15min", "1hr"]: + if timeframe in self.suite["MNQ"].data.timeframes: + analysis = await self.analyze_timeframe(timeframe) + if analysis: + analyses.append(analysis) if len(analyses) < 2: - return None, None + return None, analyses if analyses else None # Count bullish/bearish signals - bullish_signals = sum(1 for tf in analyses if tf['trend'] == 'bullish' and tf['momentum'] == 'positive') - bearish_signals = sum(1 for tf in analyses if tf['trend'] == 'bearish' and tf['momentum'] == 'negative') + bullish_signals = sum( + 1 + for tf in analyses + if tf["trend"] == "bullish" and tf["momentum"] == "positive" + ) + bearish_signals = sum( + 1 + for tf in analyses + if tf["trend"] == "bearish" and tf["momentum"] == "negative" + ) + + # Get the lowest timeframe analysis (usually 5min) + entry_tf = analyses[0] # First timeframe for entry conditions # Require confluence (majority agreement) - if bullish_signals >= 2 and tf_5min['rsi'] < 70: # Not overbought on entry timeframe + if bullish_signals >= 2 and entry_tf["rsi"] < 70: # Not overbought return "long", analyses - elif bearish_signals >= 2 and tf_5min['rsi'] > 30: # Not oversold on entry timeframe + elif bearish_signals >= 2 and entry_tf["rsi"] > 30: # Not oversold return "short", analyses return None, analyses - async def calculate_position_size(self, entry_price: float, stop_loss: float): + async def calculate_position_size( + self, entry_price: float, stop_loss: float + ) -> int: """Calculate position size based on risk management.""" - account_info = await self.suite.client.get_account_info() - account_balance = float(account_info.balance) - # Calculate risk amount - risk_amount = account_balance * float(self.risk_per_trade) + risk_amount = float(self.account_balance) * float(self.risk_per_trade) - # Calculate risk per contract + # Calculate risk per contract (MNQ = $20 per point) price_diff = abs(entry_price - stop_loss) - risk_per_contract = price_diff * 20 # MNQ multiplier + risk_per_contract = price_diff * 20 + + if risk_per_contract <= 0: + return 1 # Calculate position size calculated_size = int(risk_amount / risk_per_contract) return max(1, min(calculated_size, 5)) # Between 1-5 contracts - async def place_momentum_trade(self, direction: str, analyses: list): - """Place a trade based on momentum confluence.""" + async def calculate_atr_stops( + self, direction: str, current_price: float, timeframe: str = "5min" + ) -> tuple[float, float]: + """Calculate ATR-based stop loss and take profit.""" try: - current_price = analyses[0]['price'] # Use 5min price + # Get bars for ATR calculation + bars = await self.suite["MNQ"].data.get_data(timeframe) + if bars is None or bars.is_empty(): + # Fallback to fixed stops + if direction == "long": + return current_price - 50, current_price + 100 + else: + return current_price + 50, current_price - 100 - # Calculate ATR-based stop loss - bars_5min = await self.suite.data.get_data("5min") - atr = bars_5min.pipe(ATR, period=14) - current_atr = float(atr[-1]) + # Calculate ATR + with_atr = bars.pipe(ATR, period=14) + current_atr = float(with_atr["atr_14"].tail(1)[0]) - # Dynamic stops based on volatility + # Dynamic stops based on volatility (2x ATR stop, 3x ATR target) if direction == "long": stop_loss = current_price - (current_atr * 2) take_profit = current_price + (current_atr * 3) - side = 0 # Buy else: stop_loss = current_price + (current_atr * 2) take_profit = current_price - (current_atr * 3) - side = 1 # Sell + + return stop_loss, take_profit + + except Exception as e: + print(f"Error calculating ATR stops: {e}") + # Fallback to fixed stops + if direction == "long": + return current_price - 50, current_price + 100 + else: + return current_price + 50, current_price - 100 + + async def place_momentum_trade( + self, direction: str, analyses: list + ) -> Optional[Any]: + """Place a trade based on momentum confluence.""" + try: + # Use the entry timeframe price + current_price = analyses[0]["price"] + + # Calculate ATR-based stops + stop_loss, take_profit = await self.calculate_atr_stops( + direction, current_price + ) # Calculate position size position_size = await self.calculate_position_size(current_price, stop_loss) - print(f"\n{direction.upper()} Momentum Trade Setup:") - print(f" Entry Price: ${current_price:.2f}") - print(f" Stop Loss: ${stop_loss:.2f}") - print(f" Take Profit: ${take_profit:.2f}") - print(f" Position Size: {position_size} contracts") - print(f" Risk/Reward: {abs(take_profit - current_price) / abs(current_price - stop_loss):.2f}:1") + # Display trade setup + print("\n" + "=" * 60) + print(f"{direction.upper()} MOMENTUM TRADE SETUP") + print("=" * 60) + print(f"Entry Price: ${current_price:.2f}") + print( + f"Stop Loss: ${stop_loss:.2f} ({abs(current_price - stop_loss):.2f} points)" + ) + print( + f"Take Profit: ${take_profit:.2f} ({abs(take_profit - current_price):.2f} points)" + ) + print(f"Position Size: {position_size} contracts") + + # Calculate risk/reward + risk = abs(current_price - stop_loss) + reward = abs(take_profit - current_price) + rr_ratio = reward / risk if risk > 0 else 0 + print(f"Risk/Reward: {rr_ratio:.2f}:1") # Display confluence analysis print("\nConfluence Analysis:") for analysis in analyses: - print(f" {analysis['timeframe']}: {analysis['trend']} trend, {analysis['momentum']} momentum, RSI: {analysis['rsi']:.1f}") + print( + f" {analysis['timeframe']:5s}: " + f"{analysis['trend']:7s} trend, " + f"{analysis['momentum']:8s} momentum, " + f"RSI: {analysis['rsi']:5.1f}" + ) + print("=" * 60) # Confirm trade response = input(f"\nPlace {direction.upper()} momentum trade? (y/N): ") - if not response.lower().startswith('y'): + if not response.lower().startswith("y"): + print("Trade cancelled") return None - # Place bracket order - result = await self.suite.orders.place_bracket_order( - contract_id=self.suite.instrument_info.id, + # Get instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + # Determine side + side = 0 if direction == "long" else 1 # 0=Buy, 1=Sell + + print("\nPlacing bracket order...") + + # Place bracket order with market entry + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, side=side, size=position_size, - stop_offset=Decimal(str(abs(current_price - stop_loss))), - target_offset=Decimal(str(abs(take_profit - current_price))) + entry_price=None, # Market order + entry_type="market", + stop_loss_price=stop_loss, + take_profit_price=take_profit, ) - self.active_position = { - "direction": direction, - "entry_price": current_price, - "stop_loss": stop_loss, - "take_profit": take_profit, - "size": position_size, - "bracket": result - } + if result and result.success: + self.active_position = { + "direction": direction, + "entry_price": current_price, + "stop_loss": stop_loss, + "take_profit": take_profit, + "size": position_size, + "bracket_result": result, + } + + print("\nāœ… Momentum trade placed successfully!") + print(f" Entry Order: {result.entry_order_id}") + print(f" Stop Order: {result.stop_order_id}") + print(f" Target Order: {result.target_order_id}") + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place trade: {error_msg}") - print(f"Momentum trade placed successfully!") return result except Exception as e: print(f"Failed to place momentum trade: {e}") + import traceback + + traceback.print_exc() return None + async def main(): + """Main function to run the momentum strategy.""" + print("Initializing Multi-Timeframe Momentum Strategy...") + # Create suite with multiple timeframes suite = await TradingSuite.create( ["MNQ"], timeframes=["5min", "15min", "1hr"], initial_days=15, # More historical data for higher timeframes - features=["risk_manager"] + features=["risk_manager"], ) - mnq_context = suite["MNQ"] + mnq_context = suite["MNQ"] strategy = MultiTimeframeMomentumStrategy(suite) # Event handlers - async def on_new_bar(event): - if event.data.get('timeframe') == '5min': # Only act on 5min bars - # Check for confluence signals - direction, analyses = await strategy.check_confluence() + last_bar_time = {} - if direction and not strategy.active_position: - print(f"\n=== MOMENTUM CONFLUENCE DETECTED: {direction.upper()} ===") - await strategy.place_momentum_trade(direction, analyses) - elif analyses: - # Display current analysis - print(f"\nCurrent Analysis (no confluence):") - for analysis in analyses: - if analysis: - print(f" {analysis['timeframe']}: {analysis['trend']}/{analysis['momentum']} (RSI: {analysis['rsi']:.1f})") + async def on_new_bar(event: Event): + """Handle new bar events.""" + # Get timeframe from event data + timeframe = event.data.get("timeframe", "unknown") - async def on_order_filled(event): - if strategy.active_position: - order_id = event.data.get('order_id') - fill_price = event.data.get('fill_price', 0) + # Only act on 5min bars for trade decisions + if timeframe == "5min": + # Avoid duplicate processing + current_time = event.data.get("timestamp", "") + if current_time == last_bar_time.get(timeframe): + return + last_bar_time[timeframe] = current_time - # Check if it's our stop or target - bracket = strategy.active_position['bracket'] - if order_id in [bracket.stop_order_id, bracket.target_order_id]: - result = "STOP LOSS" if order_id == bracket.stop_order_id else "TAKE PROFIT" - print(f"\n{result} HIT: Order {order_id} filled at ${fill_price:.2f}") - strategy.active_position = None # Clear position + # Check for confluence signals + direction, analyses = await strategy.check_confluence() - # Register events + if analyses and not strategy.active_position: + if direction: + print(f"\n{'=' * 60}") + print(f"MOMENTUM CONFLUENCE DETECTED: {direction.upper()}") + print(f"{'=' * 60}") + await strategy.place_momentum_trade(direction, analyses) + else: + # Display current analysis (no confluence) + print("\nCurrent Market Analysis (No Confluence):") + for analysis in analyses: + print( + f" {analysis['timeframe']:5s}: " + f"{analysis['trend']:7s}/{analysis['momentum']:8s} " + f"(RSI: {analysis['rsi']:5.1f})" + ) + + # Register event handlers await mnq_context.on(EventType.NEW_BAR, on_new_bar) - await mnq_context.on(EventType.ORDER_FILLED, on_order_filled) - print("Multi-Timeframe Momentum Strategy Active") - print("Analyzing 5min, 15min, and 1hr timeframes for confluence...") - print("Press Ctrl+C to exit") + print("\n" + "=" * 60) + print("MULTI-TIMEFRAME MOMENTUM STRATEGY ACTIVE") + print("=" * 60) + print("Analyzing 5min, 15min, and 1hr timeframes for confluence") + print("Looking for aligned trend and momentum signals") + print("Using ATR-based dynamic stops and targets") + print("\nPress Ctrl+C to exit") + print("=" * 60) try: while True: - await asyncio.sleep(10) + await asyncio.sleep(30) # Status update every 30 seconds # Display status current_price = await mnq_context.data.get_current_price() - position_status = "ACTIVE" if strategy.active_position else "FLAT" - print(f"Price: ${current_price:.2f} | Position: {position_status}") + if current_price: + position_status = "ACTIVE" if strategy.active_position else "FLAT" + + print("\nStatus Update:") + print(f" Price: ${current_price:.2f}") + print(f" Position: {position_status}") + + if strategy.active_position: + pos = strategy.active_position + print(f" Direction: {pos['direction'].upper()}") + print(f" Entry: ${pos['entry_price']:.2f}") + print(f" Stop: ${pos['stop_loss']:.2f}") + print(f" Target: ${pos['take_profit']:.2f}") except KeyboardInterrupt: - print("\nShutting down strategy...") + print("\n\nShutting down strategy...") # Cancel active orders if any if strategy.active_position: - bracket = strategy.active_position['bracket'] - try: - await mnq_context.orders.cancel_order(bracket.main_order_id) - print("Cancelled active orders") - except Exception as e: - print(f"Error cancelling orders: {e}") + bracket_result: BracketOrderResponse = strategy.active_position.get("bracket_result", {}) + if bracket_result: + try: + # Cancel stop and target orders + if bracket_result.stop_order_id: + await mnq_context.orders.cancel_order( + bracket_result.stop_order_id + ) + if bracket_result.target_order_id: + await mnq_context.orders.cancel_order( + bracket_result.target_order_id + ) + print("Cancelled active orders") + except Exception as e: + print(f"Error cancelling orders: {e}") + + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("Strategy disconnected. Goodbye!") if __name__ == "__main__": @@ -437,58 +789,75 @@ Comprehensive risk management with position sizing and portfolio limits: ```python #!/usr/bin/env python """ -Advanced risk management system with portfolio-level controls +Advanced risk management system with portfolio-level controls. + +This example demonstrates: +- Position sizing based on risk parameters +- Portfolio risk monitoring +- Bracket orders with automatic stop-loss and take-profit +- Real-time P&L tracking +- Risk limit enforcement """ + import asyncio +from datetime import datetime from decimal import Decimal -from datetime import datetime, timedelta -from project_x_py import TradingSuite, EventType + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event + class AdvancedRiskManager: + """Advanced risk management system for trading.""" + def __init__(self, suite: TradingSuite): self.suite = suite # Risk parameters self.max_risk_per_trade = Decimal("0.02") # 2% per trade - self.max_daily_risk = Decimal("0.06") # 6% per day + self.max_daily_risk = Decimal("0.06") # 6% per day self.max_portfolio_risk = Decimal("0.20") # 20% total portfolio - self.max_positions = 3 # Maximum open positions + self.max_positions = 3 # Maximum open positions # Tracking + self.account_balance = Decimal("50000") # Default demo balance self.daily_pnl = Decimal("0") self.active_trades = [] self.daily_reset_time = datetime.now().date() - async def get_account_balance(self): - """Get current account balance.""" - account_info = await self.suite.client.get_account_info() - return Decimal(str(account_info.balance)) + async def update_account_info(self): + """Update account information.""" + try: + # Try to get positions to calculate P&L + positions = await self.suite["MNQ"].positions.get_all_positions() - async def calculate_current_portfolio_risk(self): - """Calculate current portfolio risk exposure.""" - positions = await self.suite.positions.get_all_positions() - total_risk = Decimal("0") + # Calculate total P&L from positions + total_pnl = Decimal("0") + # Note: Actual P&L calculation would depend on position attributes + # This is a placeholder for demonstration - for position in positions: - if position.size != 0: - # Estimate risk based on position size and current unrealized P&L - position_value = abs(Decimal(str(position.size * position.average_price * 20))) # MNQ multiplier - total_risk += position_value + # Update daily P&L + current_date = datetime.now().date() + if current_date > self.daily_reset_time: + self.daily_pnl = Decimal("0") + self.daily_reset_time = current_date + print(f"Daily P&L reset for {current_date}") - account_balance = await self.get_account_balance() - portfolio_risk_pct = total_risk / account_balance if account_balance > 0 else Decimal("0") + self.daily_pnl += total_pnl - return portfolio_risk_pct, total_risk + except Exception as e: + print(f"Could not update account info: {e}") - async def calculate_position_size(self, entry_price: float, stop_loss: float, risk_amount: Decimal = None): + async def calculate_position_size( + self, entry_price: float, stop_loss: float + ) -> int: """Calculate optimal position size based on risk parameters.""" - if risk_amount is None: - account_balance = await self.get_account_balance() - risk_amount = account_balance * self.max_risk_per_trade + # Calculate risk amount + risk_amount = self.account_balance * self.max_risk_per_trade - # Calculate risk per contract + # Calculate risk per contract (MNQ = $20 per point) price_diff = abs(Decimal(str(entry_price)) - Decimal(str(stop_loss))) - risk_per_contract = price_diff * 20 # MNQ multiplier + risk_per_contract = price_diff * 20 if risk_per_contract <= 0: return 0 @@ -500,53 +869,51 @@ class AdvancedRiskManager: max_size = 10 # Hard limit return max(1, min(calculated_size, max_size)) - async def check_risk_limits(self, proposed_trade: dict): + async def check_risk_limits(self, proposed_size: int) -> tuple[bool, list[str]]: """Check if proposed trade violates risk limits.""" errors = [] # Check maximum positions - if len(self.active_trades) >= self.max_positions: + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + + if len(active_positions) >= self.max_positions: errors.append(f"Maximum positions reached ({self.max_positions})") # Check daily risk limit - account_balance = await self.get_account_balance() - if abs(self.daily_pnl) >= (account_balance * self.max_daily_risk): + if abs(self.daily_pnl) >= (self.account_balance * self.max_daily_risk): errors.append(f"Daily risk limit reached ({self.max_daily_risk * 100}%)") # Check portfolio risk - portfolio_risk_pct, _ = await self.calculate_current_portfolio_risk() - if portfolio_risk_pct >= self.max_portfolio_risk: - errors.append(f"Portfolio risk limit reached ({self.max_portfolio_risk * 100}%)") - - # Check proposed trade risk - trade_risk = Decimal(str(proposed_trade['risk_amount'])) - if trade_risk > (account_balance * self.max_risk_per_trade): - errors.append(f"Trade risk too high ({self.max_risk_per_trade * 100}% max)") + total_position_size = sum(abs(p.size) for p in active_positions) + if total_position_size + proposed_size > 20: # Max 20 contracts total + errors.append("Portfolio size limit reached (20 contracts max)") return len(errors) == 0, errors - async def monitor_daily_pnl(self): - """Monitor and update daily P&L.""" - current_date = datetime.now().date() - - # Reset daily P&L if new day - if current_date > self.daily_reset_time: - self.daily_pnl = Decimal("0") - self.daily_reset_time = current_date - print(f"Daily P&L reset for {current_date}") - - # Calculate current daily P&L - positions = await self.suite.positions.get_all_positions() - total_unrealized = sum(Decimal(str(p.unrealized_pnl)) for p in positions) - total_realized = sum(Decimal(str(p.realized_pnl)) for p in positions) - - self.daily_pnl = total_realized + total_unrealized - - return self.daily_pnl - - async def place_risk_managed_trade(self, direction: str, entry_price: float, stop_loss: float, take_profit: float): + async def place_risk_managed_trade( + self, direction: str, stop_offset: float = 50, target_offset: float = 100 + ): """Place a trade with full risk management.""" try: + # Get current price + current_price = await self.suite["MNQ"].data.get_current_price() + if not current_price: + print("Could not get current price") + return None + + # Calculate entry, stop, and target prices + if direction == "long": + entry_price = current_price + stop_loss = current_price - stop_offset + take_profit = current_price + target_offset + side = 0 # Buy + else: + entry_price = current_price + stop_loss = current_price + stop_offset + take_profit = current_price - target_offset + side = 1 # Sell + # Calculate position size position_size = await self.calculate_position_size(entry_price, stop_loss) @@ -554,22 +921,8 @@ class AdvancedRiskManager: print("Position size calculated as 0 - trade rejected") return None - # Calculate trade risk - risk_per_contract = abs(Decimal(str(entry_price)) - Decimal(str(stop_loss))) * 20 - total_risk = risk_per_contract * position_size - - # Prepare trade proposal - proposed_trade = { - "direction": direction, - "entry_price": entry_price, - "stop_loss": stop_loss, - "take_profit": take_profit, - "size": position_size, - "risk_amount": total_risk - } - # Check risk limits - risk_ok, risk_errors = await self.check_risk_limits(proposed_trade) + risk_ok, risk_errors = await self.check_risk_limits(position_size) if not risk_ok: print("Trade rejected due to risk limits:") @@ -577,149 +930,198 @@ class AdvancedRiskManager: print(f" - {error}") return None - # Display trade details - account_balance = await self.get_account_balance() - risk_pct = (total_risk / account_balance) * 100 - - print(f"\nRisk-Managed Trade Setup:") - print(f" Direction: {direction.upper()}") - print(f" Entry: ${entry_price:.2f}") - print(f" Stop: ${stop_loss:.2f}") - print(f" Target: ${take_profit:.2f}") - print(f" Size: {position_size} contracts") - print(f" Risk: ${total_risk:.2f} ({risk_pct:.2f}% of account)") - print(f" R:R Ratio: {abs(take_profit - entry_price) / abs(entry_price - stop_loss):.2f}:1") - - # Show current risk status - portfolio_risk_pct, _ = await self.calculate_current_portfolio_risk() - daily_pnl = await self.monitor_daily_pnl() + # Calculate trade risk + risk_per_contract = abs(entry_price - stop_loss) * 20 # MNQ multiplier + total_risk = risk_per_contract * position_size + risk_pct = float((total_risk / float(self.account_balance)) * 100) - print(f"\nCurrent Risk Status:") - print(f" Daily P&L: ${daily_pnl:.2f}") - print(f" Portfolio Risk: {portfolio_risk_pct * 100:.2f}%") - print(f" Active Positions: {len(self.active_trades)}") + # Display trade details + print("\n" + "=" * 50) + print("RISK-MANAGED TRADE SETUP") + print("=" * 50) + print(f"Direction: {direction.upper()}") + print(f"Current Price: ${entry_price:.2f}") + print(f"Stop Loss: ${stop_loss:.2f} ({stop_offset} points)") + print(f"Take Profit: ${take_profit:.2f} ({target_offset} points)") + print(f"Position Size: {position_size} contracts") + print(f"Risk Amount: ${total_risk:.2f} ({risk_pct:.2f}% of account)") + print(f"R:R Ratio: {target_offset / stop_offset:.1f}:1") + + # Show current status + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + print("\nCurrent Status:") + print(f" Active Positions: {len(active_positions)}") + print(f" Daily P&L: ${self.daily_pnl:.2f}") + print("=" * 50) # Confirm trade - response = input(f"\nProceed with risk-managed {direction.upper()} trade? (y/N): ") - if not response.lower().startswith('y'): + response = input(f"\nProceed with {direction.upper()} trade? (y/N): ") + if not response.lower().startswith("y"): + print("Trade cancelled") return None # Place bracket order - side = 0 if direction == "long" else 1 - stop_offset = Decimal(str(abs(entry_price - stop_loss))) - target_offset = Decimal(str(abs(take_profit - entry_price))) + print("\nPlacing bracket order...") - result = await self.suite.orders.place_bracket_order( - contract_id=self.suite.instrument_info.id, + # Get the instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, side=side, size=position_size, - stop_offset=stop_offset, - target_offset=target_offset + entry_price=None, # Market entry + entry_type="market", + stop_loss_price=stop_loss, + take_profit_price=take_profit, ) - # Track the trade - trade_record = { - **proposed_trade, - "bracket": result, - "timestamp": datetime.now(), - "status": "active" - } - - self.active_trades.append(trade_record) + if result and result.success: + # Track the trade + trade_record = { + "direction": direction, + "entry_price": entry_price, + "stop_loss": stop_loss, + "take_profit": take_profit, + "size": position_size, + "risk_amount": total_risk, + "bracket_result": result, + "timestamp": datetime.now(), + "status": "active", + } + self.active_trades.append(trade_record) - print(f"Risk-managed trade placed successfully!") - print(f" Main Order: {result.main_order_id}") - print(f" Stop Order: {result.stop_order_id}") - print(f" Target Order: {result.target_order_id}") + print("\nāœ… Risk-managed trade placed successfully!") + print(f" Entry Order: {result.entry_order_id}") + print(f" Stop Order: {result.stop_order_id}") + print(f" Target Order: {result.target_order_id}") + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place trade: {error_msg}") return result except Exception as e: print(f"Failed to place risk-managed trade: {e}") + import traceback + + traceback.print_exc() return None async def generate_risk_report(self): """Generate comprehensive risk report.""" - print("\n" + "="*50) - print("RISK MANAGEMENT REPORT") - print("="*50) - - account_balance = await self.get_account_balance() - daily_pnl = await self.monitor_daily_pnl() - portfolio_risk_pct, total_risk = await self.calculate_current_portfolio_risk() - - print(f"Account Balance: ${account_balance:,.2f}") - print(f"Daily P&L: ${daily_pnl:.2f} ({(daily_pnl/account_balance)*100:.2f}%)") - print(f"Portfolio Risk: ${total_risk:,.2f} ({portfolio_risk_pct*100:.2f}%)") - print(f"Active Trades: {len(self.active_trades)}") + await self.update_account_info() - print(f"\nRisk Limits:") - print(f" Per Trade: {self.max_risk_per_trade*100:.1f}% (${account_balance * self.max_risk_per_trade:.2f})") - print(f" Daily: {self.max_daily_risk*100:.1f}% (${account_balance * self.max_daily_risk:.2f})") - print(f" Portfolio: {self.max_portfolio_risk*100:.1f}% (${account_balance * self.max_portfolio_risk:.2f})") + print("\n" + "=" * 60) + print("RISK MANAGEMENT REPORT") + print("=" * 60) + + print(f"Account Balance: ${self.account_balance:,.2f}") + print( + f"Daily P&L: ${self.daily_pnl:.2f} ({(self.daily_pnl / self.account_balance) * 100:.2f}%)" + ) + + # Get current positions + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + + print(f"\nActive Positions: {len(active_positions)}") + for i, pos in enumerate(active_positions, 1): + side = "LONG" if pos.size > 0 else "SHORT" + print(f" {i}. {side} {abs(pos.size)} contracts") + + print("\nRisk Limits:") + print( + f" Per Trade: {self.max_risk_per_trade * 100:.1f}% (${self.account_balance * self.max_risk_per_trade:.2f})" + ) + print( + f" Daily: {self.max_daily_risk * 100:.1f}% (${self.account_balance * self.max_daily_risk:.2f})" + ) + print(f" Portfolio: {self.max_portfolio_risk * 100:.1f}%") print(f" Max Positions: {self.max_positions}") if self.active_trades: - print(f"\nActive Trades:") - for i, trade in enumerate(self.active_trades, 1): - print(f" {i}. {trade['direction'].upper()} - ${trade['entry_price']:.2f} (Risk: ${trade['risk_amount']:.2f})") + print("\nRecent Trades:") + for i, trade in enumerate(self.active_trades[-5:], 1): # Show last 5 + print( + f" {i}. {trade['direction'].upper()} - " + f"${trade['entry_price']:.2f} " + f"(Risk: ${trade['risk_amount']:.2f}) - " + f"{trade['status'].upper()}" + ) + + print("=" * 60) - print("="*50) async def main(): - suite = await TradingSuite.create(["MNQ"], timeframes=["5min"], features=["risk_manager"]) + """Main function to run the risk management system.""" + print("Initializing Advanced Risk Management System...") + + # Create TradingSuite with risk management features + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["1min", "5min"], + initial_days=1, + features=["risk_manager"], # Enable risk manager feature + ) + + # Create risk manager risk_manager = AdvancedRiskManager(suite) mnq_context = suite["MNQ"] - # Event handlers - async def on_order_filled(event): - order_data = event.data - print(f"Order filled: {order_data.get('order_id')} at ${order_data.get('fill_price', 0):.2f}") - - # Update trade records - for trade in risk_manager.active_trades: - bracket = trade['bracket'] - if order_data.get('order_id') in [bracket.stop_order_id, bracket.target_order_id]: - trade['status'] = 'completed' - print(f"Trade completed: {trade['direction']} from ${trade['entry_price']:.2f}") + # Set up event handlers + async def on_new_bar(_event: Event): + """Handle new bar events to update P&L.""" + # Update account info on each new bar + await risk_manager.update_account_info() - await mnq_context.on(EventType.ORDER_FILLED, on_order_filled) + async def on_quote(_event: Event): + """Handle quote updates.""" + # Could use this for real-time P&L updates + # Placeholder for future real-time updates - print("Advanced Risk Management System Active") - print("Commands:") - print(" 'long' - Test long trade") - print(" 'short' - Test short trade") + # Register event handlers + await mnq_context.on(EventType.NEW_BAR, on_new_bar) + await mnq_context.on(EventType.QUOTE_UPDATE, on_quote) + + print("\n" + "=" * 60) + print("ADVANCED RISK MANAGEMENT SYSTEM ACTIVE") + print("=" * 60) + print("\nCommands:") + print(" 'long' - Place risk-managed LONG trade") + print(" 'short' - Place risk-managed SHORT trade") print(" 'report' - Generate risk report") - print(" 'quit' - Exit") + print(" 'quit' - Exit system") + print("=" * 60) try: while True: + # Get user input command = input("\nEnter command: ").strip().lower() - if command == 'quit': + if command == "quit": break - elif command == 'report': + elif command == "report": await risk_manager.generate_risk_report() - elif command in ['long', 'short']: - # Get current price and simulate trade levels - current_price = await mnq_context.data.get_current_price() - - if command == 'long': - entry_price = float(current_price) - stop_loss = entry_price * 0.998 # 0.2% stop - take_profit = entry_price * 1.004 # 0.4% target - else: - entry_price = float(current_price) - stop_loss = entry_price * 1.002 # 0.2% stop - take_profit = entry_price * 0.996 # 0.4% target - - await risk_manager.place_risk_managed_trade(command, entry_price, stop_loss, take_profit) + elif command == "long": + await risk_manager.place_risk_managed_trade("long") + elif command == "short": + await risk_manager.place_risk_managed_trade("short") + elif command: + print(f"Unknown command: {command}") - # Update daily P&L monitoring - await risk_manager.monitor_daily_pnl() + # Brief pause to allow async operations + await asyncio.sleep(0.1) except KeyboardInterrupt: - print("\nShutting down risk management system...") + print("\n\nShutting down risk management system...") + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("System disconnected. Goodbye!") + if __name__ == "__main__": asyncio.run(main()) diff --git a/docs/examples/realtime.md b/docs/examples/realtime.md index b120fd3..b7683c7 100644 --- a/docs/examples/realtime.md +++ b/docs/examples/realtime.md @@ -16,18 +16,30 @@ Start with simple real-time data consumption: ```python #!/usr/bin/env python """ -Basic real-time data streaming example +Basic real-time data streaming example. + +This example demonstrates: +- Connecting to real-time data feeds +- Handling tick (quote) updates +- Processing new bar events +- Monitoring connection health +- Displaying streaming statistics """ + import asyncio from datetime import datetime -from project_x_py import TradingSuite, EventType + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event + async def main(): + """Main function to run real-time data streaming.""" # Create suite with real-time capabilities suite = await TradingSuite.create( ["MNQ"], timeframes=["15sec", "1min"], - initial_days=1 # Minimal historical data + initial_days=1, # Minimal historical data ) mnq_context = suite["MNQ"] @@ -39,39 +51,65 @@ async def main(): bar_count = 0 last_price = None - async def on_tick(event): + async def on_tick(event: Event): + """Handle tick updates.""" nonlocal tick_count, last_price tick_data = event.data tick_count += 1 - last_price = tick_data.get('price', 0) + last_price = tick_data.get("last") or last_price # Display every 10th tick to avoid spam if tick_count % 10 == 0: timestamp = datetime.now().strftime("%H:%M:%S") print(f"[{timestamp}] Tick #{tick_count}: ${last_price:.2f}") - async def on_new_bar(event): + async def on_new_bar(event: Event): + """Handle new bar events.""" nonlocal bar_count - bar_data = event.data bar_count += 1 timestamp = datetime.now().strftime("%H:%M:%S") - timeframe = bar_data.get('timeframe', 'unknown') - - print(f"[{timestamp}] New {timeframe} bar #{bar_count}:") - print(f" OHLC: ${bar_data['open']:.2f} / ${bar_data['high']:.2f} / ${bar_data['low']:.2f} / ${bar_data['close']:.2f}") - print(f" Volume: {bar_data.get('volume', 0)}") - print(f" Timestamp: {bar_data.get('timestamp')}") - async def on_connection_status(event): - status = event.data.get('status', 'unknown') - print(f"Connection Status: {status}") + # The event.data contains timeframe and nested data + event_data = event.data + timeframe = event_data.get("timeframe", "unknown") + + # Get the bar data directly from the event + bar_data = event_data.get("data", {}) + + if bar_data: + print(f"[{timestamp}] New {timeframe} bar #{bar_count}:") + + # Access the bar data fields directly + open_price = bar_data.get("open", 0) + high_price = bar_data.get("high", 0) + low_price = bar_data.get("low", 0) + close_price = bar_data.get("close", 0) + volume = bar_data.get("volume", 0) + bar_timestamp = bar_data.get("timestamp", "") + + print( + f" OHLC: ${open_price:.2f} / ${high_price:.2f} / " + f"${low_price:.2f} / ${close_price:.2f}" + ) + print(f" Volume: {volume}") + print(f" Timestamp: {bar_timestamp}") + + async def on_connection_status(event: Event): + """Handle connection status changes.""" + status = event.data.get("connected", False) + print(f"Connection Status Changed: {status}") + if status: + print("āœ… Real-time feed connected") + else: + print("āŒ Real-time feed disconnected") # Register event handlers - await mnq_context.on(EventType.TICK, on_tick) + await mnq_context.on(EventType.QUOTE_UPDATE, on_tick) await mnq_context.on(EventType.NEW_BAR, on_new_bar) - await mnq_context.on(EventType.CONNECTION_STATUS, on_connection_status) + await mnq_context.on(EventType.CONNECTED, on_connection_status) + await mnq_context.on(EventType.DISCONNECTED, on_connection_status) print("Listening for real-time data... Press Ctrl+C to exit") @@ -81,12 +119,21 @@ async def main(): # Display periodic status current_price = await mnq_context.data.get_current_price() - connection_health = await mnq_context.data.get_connection_health() + connection_health = await mnq_context.data.get_health_score() - print(f"Status - Price: ${current_price:.2f} | Ticks: {tick_count} | Bars: {bar_count} | Health: {connection_health}") + print( + f"Status - Price: ${current_price:.2f} | " + f"Ticks: {tick_count} | Bars: {bar_count} | " + f"Health: {connection_health}" + ) except KeyboardInterrupt: print("\nShutting down real-time stream...") + finally: + # Ensure proper cleanup + await suite.disconnect() + print("Disconnected from real-time feeds") + if __name__ == "__main__": asyncio.run(main()) @@ -101,11 +148,14 @@ Handle multiple timeframes with proper synchronization: """ Multi-timeframe real-time data synchronization """ + import asyncio from collections import defaultdict from datetime import datetime -from project_x_py import TradingSuite, EventType -from project_x_py.indicators import SMA, RSI + +from project_x_py import EventType, TradingSuite +from project_x_py.indicators import RSI, SMA + class MultiTimeframeDataProcessor: def __init__(self, suite: TradingSuite): @@ -117,8 +167,8 @@ class MultiTimeframeDataProcessor: async def process_new_bar(self, event): """Process incoming bar data for all timeframes.""" - bar_data = event.data - timeframe = bar_data.get('timeframe', 'unknown') + bar_data = event.data.get("data", event.data) + timeframe = event.data.get("timeframe", "unknown") if timeframe not in self.timeframes: return @@ -130,7 +180,9 @@ class MultiTimeframeDataProcessor: if len(self.data_cache[timeframe]) > 200: self.data_cache[timeframe] = self.data_cache[timeframe][-100:] - print(f"New {timeframe} bar: ${bar_data['close']:.2f} @ {bar_data.get('timestamp')}") + print( + f"New {timeframe} bar: ${bar_data['close']:.2f} @ {bar_data.get('timestamp')}" + ) # Perform analysis on this timeframe await self.analyze_timeframe(timeframe) @@ -143,22 +195,29 @@ class MultiTimeframeDataProcessor: """Analyze a specific timeframe with technical indicators.""" try: # Get fresh data from suite - bars = await self.suite.data.get_data(timeframe) + bars = await self.suite["MNQ"].data.get_data(timeframe) + + if bars is None: + return if len(bars) < 50: # Need enough data for indicators return # Calculate indicators - sma_20 = bars.pipe(SMA, period=20) - rsi = bars.pipe(RSI, period=14) + bars = bars.pipe(SMA, period=20).pipe(RSI, period=14) - current_price = bars['close'][-1] - current_sma = sma_20[-1] - current_rsi = rsi[-1] + current_price = bars["close"][-1] + current_sma = bars["sma_20"][-1] + current_rsi = bars["rsi_14"][-1] # Determine trend and momentum trend = "bullish" if current_price > current_sma else "bearish" - momentum = "strong" if (trend == "bullish" and current_rsi > 50) or (trend == "bearish" and current_rsi < 50) else "weak" + momentum = ( + "strong" + if (trend == "bullish" and current_rsi > 50) + or (trend == "bearish" and current_rsi < 50) + else "weak" + ) # Store analysis self.last_analysis[timeframe] = { @@ -167,10 +226,12 @@ class MultiTimeframeDataProcessor: "rsi": current_rsi, "trend": trend, "momentum": momentum, - "timestamp": datetime.now() + "timestamp": datetime.now(), } - print(f" {timeframe} Analysis - Trend: {trend}, RSI: {current_rsi:.1f}, Momentum: {momentum}") + print( + f" {timeframe} Analysis - Trend: {trend}, RSI: {current_rsi:.1f}, Momentum: {momentum}" + ) except Exception as e: print(f"Error analyzing {timeframe}: {e}") @@ -187,19 +248,29 @@ class MultiTimeframeDataProcessor: return # Count bullish/bearish signals - bullish_count = sum(1 for analysis in self.last_analysis.values() - if analysis.get('trend') == 'bullish') - bearish_count = sum(1 for analysis in self.last_analysis.values() - if analysis.get('trend') == 'bearish') + bullish_count = sum( + 1 + for analysis in self.last_analysis.values() + if analysis.get("trend") == "bullish" + ) + bearish_count = sum( + 1 + for analysis in self.last_analysis.values() + if analysis.get("trend") == "bearish" + ) # Check for strong confluence total_timeframes = len(self.last_analysis) if bullish_count >= total_timeframes * 0.8: # 80% agreement - print(f"\n= BULLISH CONFLUENCE DETECTED ({bullish_count}/{total_timeframes})") + print( + f"\n= BULLISH CONFLUENCE DETECTED ({bullish_count}/{total_timeframes})" + ) await self.display_confluence_analysis("BULLISH") elif bearish_count >= total_timeframes * 0.8: - print(f"\n=4 BEARISH CONFLUENCE DETECTED ({bearish_count}/{total_timeframes})") + print( + f"\n=4 BEARISH CONFLUENCE DETECTED ({bearish_count}/{total_timeframes})" + ) await self.display_confluence_analysis("BEARISH") async def display_confluence_analysis(self, signal_type: str): @@ -208,10 +279,12 @@ class MultiTimeframeDataProcessor: print("-" * 40) for tf, analysis in self.last_analysis.items(): - trend_emoji = "=" if analysis['trend'] == 'bullish' else "=" - momentum_emoji = "=" if analysis['momentum'] == 'strong' else "=" + trend_emoji = "=" if analysis["trend"] == "bullish" else "=" + momentum_emoji = "=" if analysis["momentum"] == "strong" else "=" - print(f" {tf:>5} {trend_emoji} {analysis['trend']:>8} | RSI: {analysis['rsi']:>5.1f} | {momentum_emoji} {analysis['momentum']}") + print( + f" {tf:>5} {trend_emoji} {analysis['trend']:>8} | RSI: {analysis['rsi']:>5.1f} | {momentum_emoji} {analysis['momentum']}" + ) print("-" * 40) @@ -220,12 +293,13 @@ class MultiTimeframeDataProcessor: print(f"Current Price: ${current_price:.2f}") print() + async def main(): # Create suite with multiple timeframes suite = await TradingSuite.create( "MNQ", timeframes=["1min", "5min", "15min"], - initial_days=3 # Enough data for indicators + initial_days=3, # Enough data for indicators ) processor = MultiTimeframeDataProcessor(suite) @@ -246,13 +320,16 @@ async def main(): for tf in processor.timeframes: cached_bars = len(processor.data_cache[tf]) analysis = processor.last_analysis.get(tf, {}) - trend = analysis.get('trend', 'unknown') - rsi = analysis.get('rsi', 0) - print(f" {tf}: {cached_bars} bars cached, {trend} trend, RSI: {rsi:.1f}") + trend = analysis.get("trend", "unknown") + rsi = analysis.get("rsi", 0) + print( + f" {tf}: {cached_bars} bars cached, {trend} trend, RSI: {rsi:.1f}" + ) except KeyboardInterrupt: print("\nShutting down multi-timeframe processor...") + if __name__ == "__main__": asyncio.run(main()) ``` @@ -266,12 +343,15 @@ Export real-time data and create visualizations: """ Real-time data export with CSV logging and Plotly visualization """ + import asyncio import csv import json from datetime import datetime, timedelta from pathlib import Path -from project_x_py import TradingSuite, EventType + +from project_x_py import EventType, TradingSuite + class RealTimeDataExporter: def __init__(self, suite: TradingSuite, export_dir: str = "data_exports"): @@ -292,54 +372,59 @@ class RealTimeDataExporter: """Initialize CSV files for data export.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - # Tick data CSV - tick_file = self.export_dir / f"ticks_{timestamp}.csv" - tick_csv = open(tick_file, 'w', newline='') - tick_writer = csv.writer(tick_csv) - tick_writer.writerow(['timestamp', 'price', 'size', 'bid', 'ask']) - self.csv_files['ticks'] = {'file': tick_csv, 'writer': tick_writer} - # Bar data CSV bar_file = self.export_dir / f"bars_{timestamp}.csv" - bar_csv = open(bar_file, 'w', newline='') + bar_csv = open(bar_file, "w", newline="") bar_writer = csv.writer(bar_csv) - bar_writer.writerow(['timestamp', 'timeframe', 'open', 'high', 'low', 'close', 'volume']) - self.csv_files['bars'] = {'file': bar_csv, 'writer': bar_writer} + bar_writer.writerow( + ["timestamp", "timeframe", "open", "high", "low", "close", "volume"] + ) + self.csv_files["bars"] = {"file": bar_csv, "writer": bar_writer} print(f"Export files initialized in {self.export_dir}") async def process_bar(self, event): """Process and export bar data.""" - bar_data = event.data timestamp = datetime.now().isoformat() + # Get the real data for the timeframe + # Data from the event is from the new bar that was just started, so we need to get the previous bar + real_data = await self.suite["MNQ"].data.get_data( + event.data.get("timeframe", "unknown") + ) + + if real_data is None: + return + # Store in memory bar_record = { - 'timestamp': timestamp, - 'bar_timestamp': bar_data.get('timestamp'), - 'timeframe': bar_data.get('timeframe', 'unknown'), - 'open': bar_data.get('open', 0), - 'high': bar_data.get('high', 0), - 'low': bar_data.get('low', 0), - 'close': bar_data.get('close', 0), - 'volume': bar_data.get('volume', 0) + "timestamp": timestamp, + "bar_timestamp": real_data["timestamp"][-2], + "timeframe": event.data.get("timeframe", "unknown"), + "open": real_data["open"][-2], + "high": real_data["high"][-2], + "low": real_data["low"][-2], + "close": real_data["close"][-2], + "volume": real_data["volume"][-2], } self.bar_data.append(bar_record) # Write to CSV - if 'bars' in self.csv_files: - writer = self.csv_files['bars']['writer'] - writer.writerow([ - bar_record['bar_timestamp'] or timestamp, - bar_record['timeframe'], - bar_record['open'], - bar_record['high'], - bar_record['low'], - bar_record['close'], - bar_record['volume'] - ]) - self.csv_files['bars']['file'].flush() + if "bars" in self.csv_files: + writer = self.csv_files["bars"]["writer"] + writer.writerow( + [ + bar_record["bar_timestamp"] or timestamp, + bar_record["timeframe"], + bar_record["open"], + bar_record["high"], + bar_record["low"], + bar_record["close"], + bar_record["volume"], + ] + ) + self.csv_files["bars"]["file"].flush() print(f"Exported {bar_record['timeframe']} bar: ${bar_record['close']:.2f}") @@ -348,21 +433,17 @@ class RealTimeDataExporter: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") snapshot = { - 'export_timestamp': datetime.now().isoformat(), - 'data_summary': { - 'tick_count': len(self.tick_data), - 'bar_count': len(self.bar_data), - 'trade_count': len(self.trade_data) + "export_timestamp": datetime.now().isoformat(), + "data_summary": { + "bar_count": len(self.bar_data), + }, + "recent_data": { + "bars": self.bar_data[-5:], # Last 5 bars }, - 'recent_data': { - 'ticks': self.tick_data[-10:], # Last 10 ticks - 'bars': self.bar_data[-5:], # Last 5 bars - 'trades': self.trade_data[-20:] # Last 20 trades - } } json_file = self.export_dir / f"snapshot_{timestamp}.json" - with open(json_file, 'w') as f: + with open(json_file, "w") as f: json.dump(snapshot, f, indent=2) print(f"JSON snapshot exported: {json_file}") @@ -371,17 +452,18 @@ class RealTimeDataExporter: def close_files(self): """Close all open CSV files.""" for file_info in self.csv_files.values(): - file_info['file'].close() + file_info["file"].close() print("Export files closed") + async def main(): # Create suite for data export suite = await TradingSuite.create( - "MNQ", - timeframes=["1min", "5min"], - initial_days=1 + "MNQ", timeframes=["15sec", "1min", "5min"], initial_days=1 ) + mnq_context = suite["MNQ"] + exporter = RealTimeDataExporter(suite) await exporter.initialize_export_files() @@ -400,7 +482,10 @@ async def main(): export_timer += 10 # Periodic status - current_price = await suite.data.get_current_price() + current_price = await mnq_context.data.get_current_price() + if current_price is None: + continue + print(f"Price: ${current_price:.2f} | Bars: {len(exporter.bar_data)}") # Auto-export JSON snapshot every 5 minutes @@ -420,6 +505,7 @@ async def main(): print("Data export complete!") + if __name__ == "__main__": asyncio.run(main()) ``` diff --git a/examples/advanced_trading/01_advanced_bracket_order_strategy.py b/examples/advanced_trading/01_advanced_bracket_order_strategy.py new file mode 100644 index 0000000..abab949 --- /dev/null +++ b/examples/advanced_trading/01_advanced_bracket_order_strategy.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python +""" +Advanced bracket order strategy with dynamic stops based on ATR. + +This example demonstrates: +- ATR-based dynamic stop loss and take profit levels +- RSI and SMA-based entry signals +- Bracket orders with automatic price alignment +- Real-time order monitoring and management +- Event-driven trade execution +""" + +import asyncio +from decimal import Decimal +from typing import Optional + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event +from project_x_py.indicators import ATR, RSI, SMA +from project_x_py.models import BracketOrderResponse + + +class ATRBracketStrategy: + """Advanced bracket order strategy using ATR for dynamic stops.""" + + def __init__(self, suite: TradingSuite): + self.suite = suite + self.atr_period = 14 + self.rsi_period = 14 + self.sma_period = 20 + self.position_size = 1 + self.active_orders: list[BracketOrderResponse] = [] + self.max_positions = 1 # Limit concurrent positions + + async def calculate_dynamic_levels(self) -> tuple[float, float]: + """Calculate stop and target levels based on ATR.""" + try: + # Get bars for ATR calculation + bars = await self.suite["MNQ"].data.get_data("5min") + + if bars is None or bars.is_empty(): + print("No data available for ATR calculation") + # Return default values + return 50.0, 100.0 + + # Calculate ATR for volatility-based stops + with_atr = bars.pipe(ATR, period=self.atr_period) + + # Get current ATR value + atr_column = f"atr_{self.atr_period}" + if atr_column not in with_atr.columns: + print(f"ATR column {atr_column} not found") + return 50.0, 100.0 + + current_atr = float(with_atr[atr_column].tail(1)[0]) + + # Dynamic stop loss: 2x ATR + stop_offset = current_atr * 2 + + # Dynamic take profit: 3x ATR (1.5:1 reward:risk) + target_offset = current_atr * 3 + + return stop_offset, target_offset + + except Exception as e: + print(f"Error calculating ATR levels: {e}") + # Return default values on error + return 50.0, 100.0 + + async def check_entry_conditions(self) -> tuple[Optional[str], Optional[float]]: + """Check if conditions are met for entry.""" + try: + bars = await self.suite["MNQ"].data.get_data("5min") + + if bars is None or bars.is_empty(): + return None, None + + # Ensure we have enough data for indicators + if len(bars) < max(self.rsi_period, self.sma_period, self.atr_period): + return None, None + + # Calculate indicators using pipe method + with_rsi = bars.pipe(RSI, period=self.rsi_period) + with_sma = with_rsi.pipe(SMA, period=self.sma_period) + + # Get current values from the last row + last_row = with_sma.tail(1) + + current_price = float(last_row["close"][0]) + + # Get RSI value + rsi_column = f"rsi_{self.rsi_period}" + current_rsi = ( + float(last_row[rsi_column][0]) + if rsi_column in last_row.columns + else 50.0 + ) + + # Get SMA value + sma_column = f"sma_{self.sma_period}" + current_sma = ( + float(last_row[sma_column][0]) + if sma_column in last_row.columns + else current_price + ) + + # Long signal: Price above SMA and RSI oversold recovery + if current_price > current_sma and 30 < current_rsi < 50: + return "long", current_price + + # Short signal: Price below SMA and RSI overbought decline + elif current_price < current_sma and 50 < current_rsi < 70: + return "short", current_price + + return None, None + + except Exception as e: + print(f"Error checking entry conditions: {e}") + return None, None + + async def place_bracket_order( + self, direction: str + ) -> Optional[BracketOrderResponse]: + """Place a bracket order based on strategy conditions.""" + try: + # Get current price + current_price = await self.suite["MNQ"].data.get_current_price() + if not current_price: + print("Could not get current price") + return None + + # Calculate dynamic stop and target levels (offsets) + stop_offset, target_offset = await self.calculate_dynamic_levels() + + # Calculate actual price levels + if direction == "long": + stop_loss_price = current_price - stop_offset + take_profit_price = current_price + target_offset + side = 0 # Buy + else: # short + stop_loss_price = current_price + stop_offset + take_profit_price = current_price - target_offset + side = 1 # Sell + + # Display trade setup + print("\n" + "=" * 60) + print(f"{direction.upper()} BRACKET ORDER SETUP") + print("=" * 60) + print(f"Current Price: ${current_price:.2f}") + print(f"Position Size: {self.position_size} contracts") + print(f"Stop Loss: ${stop_loss_price:.2f} ({stop_offset:.2f} points)") + print(f"Take Profit: ${take_profit_price:.2f} ({target_offset:.2f} points)") + + # Calculate risk/reward + risk = abs(current_price - stop_loss_price) + reward = abs(take_profit_price - current_price) + rr_ratio = reward / risk if risk > 0 else 0 + print(f"Risk/Reward Ratio: {rr_ratio:.2f}:1") + print("=" * 60) + + # Get instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + print("\nPlacing bracket order...") + + # Place bracket order with market entry + # Prices will be automatically aligned to tick size + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, + side=side, + size=self.position_size, + entry_price=None, # Market order + entry_type="market", + stop_loss_price=stop_loss_price, + take_profit_price=take_profit_price, + ) + + if result and result.success: + print("\nāœ… Bracket order placed successfully!") + print(f" Entry Order ID: {result.entry_order_id}") + print(f" Stop Order ID: {result.stop_order_id}") + print(f" Target Order ID: {result.target_order_id}") + + self.active_orders.append(result) + return result + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place bracket order: {error_msg}") + return None + + except Exception as e: + print(f"Failed to place bracket order: {e}") + import traceback + + traceback.print_exc() + return None + + async def monitor_orders(self): + """Monitor active orders and handle fills/cancellations.""" + if not self.active_orders: + return + + # Copy list to allow modification during iteration + for bracket in self.active_orders[:]: + try: + if bracket is None: + continue + + # For this example, we'll just track the count + # In production, you would check order status via the API + + # Note: The actual order monitoring would typically be done + # through event handlers rather than polling + + except Exception as e: + print(f"Error monitoring orders: {e}") + + def remove_completed_order(self, order_id: int): + """Remove a completed order from tracking.""" + self.active_orders = [ + bracket + for bracket in self.active_orders + if bracket and bracket.entry_order_id != order_id + ] + + +async def main(): + """Main function to run the ATR bracket strategy.""" + print("Initializing Advanced Bracket Order Strategy...") + + # Create trading suite with required timeframes + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["1min", "5min"], + initial_days=10, # Need historical data for indicators + features=["risk_manager"], + ) + + # Initialize strategy + strategy = ATRBracketStrategy(suite) + mnq_context = suite["MNQ"] + + # Track last bar time to avoid duplicate processing + last_bar_time = {} + + # Set up event handlers for real-time monitoring + async def on_new_bar(event: Event): + """Handle new bar events.""" + timeframe = event.data.get("timeframe", "unknown") + + if timeframe == "5min": + # Avoid duplicate processing + current_time = event.data.get("timestamp", "") + if current_time == last_bar_time.get(timeframe): + return + last_bar_time[timeframe] = current_time + + # Get bar data + bar_data = event.data.get("data", {}) + close_price = bar_data.get("close", 0) + + if close_price: + print(f"\nNew 5min bar: ${close_price:.2f}") + + # Check if we can take a new position + if len(strategy.active_orders) >= strategy.max_positions: + return + + # Check for entry signals + direction, price = await strategy.check_entry_conditions() + if direction: + print( + f"\nšŸŽÆ Entry signal detected: {direction.upper()} at ${price:.2f}" + ) + + # Auto-confirm for demo, or ask user + if False: # Set to True for auto-trading + await strategy.place_bracket_order(direction) + else: + # Confirm with user before placing order + response = input( + f"Place {direction.upper()} bracket order? (y/N): " + ) + if response.lower().startswith("y"): + await strategy.place_bracket_order(direction) + + # Register event handlers + await mnq_context.on(EventType.NEW_BAR, on_new_bar) + + print("\n" + "=" * 60) + print("ADVANCED BRACKET ORDER STRATEGY ACTIVE") + print("=" * 60) + print("Strategy Settings:") + print(f" ATR Period: {strategy.atr_period}") + print(f" RSI Period: {strategy.rsi_period}") + print(f" SMA Period: {strategy.sma_period}") + print(f" Position Size: {strategy.position_size} contracts") + print(f" Max Positions: {strategy.max_positions}") + print("\nMonitoring for entry signals on 5-minute bars...") + print("Press Ctrl+C to exit") + print("=" * 60) + + try: + while True: + await asyncio.sleep(30) # Status update every 30 seconds + + # Monitor active orders + await strategy.monitor_orders() + + # Display current market info + current_price = await mnq_context.data.get_current_price() + if current_price: + active_count = len(strategy.active_orders) + print( + f"\nStatus: Price=${current_price:.2f} | Active Orders={active_count}" + ) + + except KeyboardInterrupt: + print("\n\nShutting down strategy...") + + # Cancel any remaining orders + for bracket in strategy.active_orders: + if bracket: + try: + # Cancel stop and target orders + if bracket.stop_order_id: + await mnq_context.orders.cancel_order(bracket.stop_order_id) + print(f"Cancelled stop order {bracket.stop_order_id}") + if bracket.target_order_id: + await mnq_context.orders.cancel_order(bracket.target_order_id) + print(f"Cancelled target order {bracket.target_order_id}") + except Exception as e: + print(f"Error cancelling orders: {e}") + + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("Strategy disconnected. Goodbye!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/advanced_trading/02_multi_timeframe_momentum_strategy.py b/examples/advanced_trading/02_multi_timeframe_momentum_strategy.py new file mode 100644 index 0000000..939bb62 --- /dev/null +++ b/examples/advanced_trading/02_multi_timeframe_momentum_strategy.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python +""" +Multi-timeframe momentum strategy with confluence analysis. + +This example demonstrates: +- Multi-timeframe analysis (5min, 15min, 1hr) +- Momentum and trend confluence detection +- Technical indicators (RSI, MACD, EMA, ATR) +- Dynamic position sizing based on ATR +- Bracket orders with volatility-based stops +""" + +import asyncio +from decimal import Decimal +from typing import Any, Optional + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event +from project_x_py.indicators import ATR, EMA, MACD, RSI +from project_x_py.models import BracketOrderResponse + + +class MultiTimeframeMomentumStrategy: + """Multi-timeframe momentum trading strategy.""" + + def __init__(self, suite: TradingSuite): + self.suite = suite + self.position_size = 1 + self.risk_per_trade = Decimal("0.02") # 2% risk per trade + self.account_balance = Decimal("50000") # Default balance + self.active_position: Optional[dict[str, Any]] = None + + async def analyze_timeframe(self, timeframe: str) -> Optional[dict[str, Any]]: + """Analyze a specific timeframe for momentum signals.""" + try: + # Get bars for the timeframe + bars = await self.suite["MNQ"].data.get_data(timeframe) + + if bars is None or bars.is_empty(): + print(f"No data available for {timeframe}") + return None + + if len(bars) < 50: # Need sufficient data for indicators + print(f"Insufficient data for {timeframe} (need 50+ bars)") + return None + + # Calculate indicators using pipe method + with_rsi = bars.pipe(RSI, period=14) + with_macd = with_rsi.pipe( + MACD, fast_period=12, slow_period=26, signal_period=9 + ) + with_ema20 = with_macd.pipe(EMA, period=20) + with_ema50 = with_ema20.pipe(EMA, period=50) + + # Get the last row for current values + last_row = with_ema50.tail(1) + + # Extract values from the last row + current_price = float(last_row["close"][0]) + current_rsi = ( + float(last_row["rsi_14"][0]) if "rsi_14" in last_row.columns else 50.0 + ) + + # MACD values + current_macd = ( + float(last_row["macd"][0]) if "macd" in last_row.columns else 0.0 + ) + macd_signal = ( + float(last_row["macd_signal"][0]) + if "macd_signal" in last_row.columns + else 0.0 + ) + + # EMA values + current_ema_20 = ( + float(last_row["ema_20"][0]) + if "ema_20" in last_row.columns + else current_price + ) + current_ema_50 = ( + float(last_row["ema_50"][0]) + if "ema_50" in last_row.columns + else current_price + ) + + # Determine trend and momentum + trend = "bullish" if current_ema_20 > current_ema_50 else "bearish" + momentum = "positive" if current_macd > macd_signal else "negative" + rsi_level = ( + "oversold" + if current_rsi < 30 + else "overbought" + if current_rsi > 70 + else "neutral" + ) + + return { + "timeframe": timeframe, + "price": current_price, + "trend": trend, + "momentum": momentum, + "rsi_level": rsi_level, + "rsi": current_rsi, + "macd": current_macd, + "macd_signal": macd_signal, + "ema_20": current_ema_20, + "ema_50": current_ema_50, + } + + except Exception as e: + print(f"Error analyzing {timeframe}: {e}") + return None + + async def check_confluence(self) -> tuple[Optional[str], Optional[list]]: + """Check for confluence across multiple timeframes.""" + # Analyze all configured timeframes + analyses = [] + + for timeframe in ["5min", "15min", "1hr"]: + if timeframe in self.suite["MNQ"].data.timeframes: + analysis = await self.analyze_timeframe(timeframe) + if analysis: + analyses.append(analysis) + + if len(analyses) < 2: + return None, analyses if analyses else None + + # Count bullish/bearish signals + bullish_signals = sum( + 1 + for tf in analyses + if tf["trend"] == "bullish" and tf["momentum"] == "positive" + ) + bearish_signals = sum( + 1 + for tf in analyses + if tf["trend"] == "bearish" and tf["momentum"] == "negative" + ) + + # Get the lowest timeframe analysis (usually 5min) + entry_tf = analyses[0] # First timeframe for entry conditions + + # Require confluence (majority agreement) + if bullish_signals >= 2 and entry_tf["rsi"] < 70: # Not overbought + return "long", analyses + elif bearish_signals >= 2 and entry_tf["rsi"] > 30: # Not oversold + return "short", analyses + + return None, analyses + + async def calculate_position_size( + self, entry_price: float, stop_loss: float + ) -> int: + """Calculate position size based on risk management.""" + # Calculate risk amount + risk_amount = float(self.account_balance) * float(self.risk_per_trade) + + # Calculate risk per contract (MNQ = $20 per point) + price_diff = abs(entry_price - stop_loss) + risk_per_contract = price_diff * 20 + + if risk_per_contract <= 0: + return 1 + + # Calculate position size + calculated_size = int(risk_amount / risk_per_contract) + return max(1, min(calculated_size, 5)) # Between 1-5 contracts + + async def calculate_atr_stops( + self, direction: str, current_price: float, timeframe: str = "5min" + ) -> tuple[float, float]: + """Calculate ATR-based stop loss and take profit.""" + try: + # Get bars for ATR calculation + bars = await self.suite["MNQ"].data.get_data(timeframe) + if bars is None or bars.is_empty(): + # Fallback to fixed stops + if direction == "long": + return current_price - 50, current_price + 100 + else: + return current_price + 50, current_price - 100 + + # Calculate ATR + with_atr = bars.pipe(ATR, period=14) + current_atr = float(with_atr["atr_14"].tail(1)[0]) + + # Dynamic stops based on volatility (2x ATR stop, 3x ATR target) + if direction == "long": + stop_loss = current_price - (current_atr * 2) + take_profit = current_price + (current_atr * 3) + else: + stop_loss = current_price + (current_atr * 2) + take_profit = current_price - (current_atr * 3) + + return stop_loss, take_profit + + except Exception as e: + print(f"Error calculating ATR stops: {e}") + # Fallback to fixed stops + if direction == "long": + return current_price - 50, current_price + 100 + else: + return current_price + 50, current_price - 100 + + async def place_momentum_trade( + self, direction: str, analyses: list + ) -> Optional[Any]: + """Place a trade based on momentum confluence.""" + try: + # Use the entry timeframe price + current_price = analyses[0]["price"] + + # Calculate ATR-based stops + stop_loss, take_profit = await self.calculate_atr_stops( + direction, current_price + ) + + # Calculate position size + position_size = await self.calculate_position_size(current_price, stop_loss) + + # Display trade setup + print("\n" + "=" * 60) + print(f"{direction.upper()} MOMENTUM TRADE SETUP") + print("=" * 60) + print(f"Entry Price: ${current_price:.2f}") + print( + f"Stop Loss: ${stop_loss:.2f} ({abs(current_price - stop_loss):.2f} points)" + ) + print( + f"Take Profit: ${take_profit:.2f} ({abs(take_profit - current_price):.2f} points)" + ) + print(f"Position Size: {position_size} contracts") + + # Calculate risk/reward + risk = abs(current_price - stop_loss) + reward = abs(take_profit - current_price) + rr_ratio = reward / risk if risk > 0 else 0 + print(f"Risk/Reward: {rr_ratio:.2f}:1") + + # Display confluence analysis + print("\nConfluence Analysis:") + for analysis in analyses: + print( + f" {analysis['timeframe']:5s}: " + f"{analysis['trend']:7s} trend, " + f"{analysis['momentum']:8s} momentum, " + f"RSI: {analysis['rsi']:5.1f}" + ) + print("=" * 60) + + # Confirm trade + response = input(f"\nPlace {direction.upper()} momentum trade? (y/N): ") + if not response.lower().startswith("y"): + print("Trade cancelled") + return None + + # Get instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + # Determine side + side = 0 if direction == "long" else 1 # 0=Buy, 1=Sell + + print("\nPlacing bracket order...") + + # Place bracket order with market entry + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, + side=side, + size=position_size, + entry_price=None, # Market order + entry_type="market", + stop_loss_price=stop_loss, + take_profit_price=take_profit, + ) + + if result and result.success: + self.active_position = { + "direction": direction, + "entry_price": current_price, + "stop_loss": stop_loss, + "take_profit": take_profit, + "size": position_size, + "bracket_result": result, + } + + print("\nāœ… Momentum trade placed successfully!") + print(f" Entry Order: {result.entry_order_id}") + print(f" Stop Order: {result.stop_order_id}") + print(f" Target Order: {result.target_order_id}") + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place trade: {error_msg}") + + return result + + except Exception as e: + print(f"Failed to place momentum trade: {e}") + import traceback + + traceback.print_exc() + return None + + +async def main(): + """Main function to run the momentum strategy.""" + print("Initializing Multi-Timeframe Momentum Strategy...") + + # Create suite with multiple timeframes + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["5min", "15min", "1hr"], + initial_days=15, # More historical data for higher timeframes + features=["risk_manager"], + ) + + mnq_context = suite["MNQ"] + strategy = MultiTimeframeMomentumStrategy(suite) + + # Event handlers + last_bar_time = {} + + async def on_new_bar(event: Event): + """Handle new bar events.""" + # Get timeframe from event data + timeframe = event.data.get("timeframe", "unknown") + + # Only act on 5min bars for trade decisions + if timeframe == "5min": + # Avoid duplicate processing + current_time = event.data.get("timestamp", "") + if current_time == last_bar_time.get(timeframe): + return + last_bar_time[timeframe] = current_time + + # Check for confluence signals + direction, analyses = await strategy.check_confluence() + + if analyses and not strategy.active_position: + if direction: + print(f"\n{'=' * 60}") + print(f"MOMENTUM CONFLUENCE DETECTED: {direction.upper()}") + print(f"{'=' * 60}") + await strategy.place_momentum_trade(direction, analyses) + else: + # Display current analysis (no confluence) + print("\nCurrent Market Analysis (No Confluence):") + for analysis in analyses: + print( + f" {analysis['timeframe']:5s}: " + f"{analysis['trend']:7s}/{analysis['momentum']:8s} " + f"(RSI: {analysis['rsi']:5.1f})" + ) + + # Register event handlers + await mnq_context.on(EventType.NEW_BAR, on_new_bar) + + print("\n" + "=" * 60) + print("MULTI-TIMEFRAME MOMENTUM STRATEGY ACTIVE") + print("=" * 60) + print("Analyzing 5min, 15min, and 1hr timeframes for confluence") + print("Looking for aligned trend and momentum signals") + print("Using ATR-based dynamic stops and targets") + print("\nPress Ctrl+C to exit") + print("=" * 60) + + try: + while True: + await asyncio.sleep(30) # Status update every 30 seconds + + # Display status + current_price = await mnq_context.data.get_current_price() + if current_price: + position_status = "ACTIVE" if strategy.active_position else "FLAT" + + print("\nStatus Update:") + print(f" Price: ${current_price:.2f}") + print(f" Position: {position_status}") + + if strategy.active_position: + pos = strategy.active_position + print(f" Direction: {pos['direction'].upper()}") + print(f" Entry: ${pos['entry_price']:.2f}") + print(f" Stop: ${pos['stop_loss']:.2f}") + print(f" Target: ${pos['take_profit']:.2f}") + + except KeyboardInterrupt: + print("\n\nShutting down strategy...") + + # Cancel active orders if any + if strategy.active_position: + bracket_result: BracketOrderResponse = strategy.active_position.get( + "bracket_result", {} + ) + if bracket_result: + try: + # Cancel stop and target orders + if bracket_result.stop_order_id: + await mnq_context.orders.cancel_order( + bracket_result.stop_order_id + ) + if bracket_result.target_order_id: + await mnq_context.orders.cancel_order( + bracket_result.target_order_id + ) + print("Cancelled active orders") + except Exception as e: + print(f"Error cancelling orders: {e}") + + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("Strategy disconnected. Goodbye!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/advanced_trading/03_advanced_risk_management_system.py b/examples/advanced_trading/03_advanced_risk_management_system.py new file mode 100644 index 0000000..e340f1a --- /dev/null +++ b/examples/advanced_trading/03_advanced_risk_management_system.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python +""" +Advanced risk management system with portfolio-level controls. + +This example demonstrates: +- Position sizing based on risk parameters +- Portfolio risk monitoring +- Bracket orders with automatic stop-loss and take-profit +- Real-time P&L tracking +- Risk limit enforcement +""" + +import asyncio +from datetime import datetime +from decimal import Decimal + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event + + +class AdvancedRiskManager: + """Advanced risk management system for trading.""" + + def __init__(self, suite: TradingSuite): + self.suite = suite + + # Risk parameters + self.max_risk_per_trade = Decimal("0.02") # 2% per trade + self.max_daily_risk = Decimal("0.06") # 6% per day + self.max_portfolio_risk = Decimal("0.20") # 20% total portfolio + self.max_positions = 3 # Maximum open positions + + # Tracking + self.account_balance = Decimal("50000") # Default demo balance + self.daily_pnl = Decimal("0") + self.active_trades = [] + self.daily_reset_time = datetime.now().date() + + async def update_account_info(self): + """Update account information.""" + try: + # Try to get positions to calculate P&L + positions = await self.suite["MNQ"].positions.get_all_positions() + + # Calculate total P&L from positions + total_pnl = Decimal("0") + # Note: Actual P&L calculation would depend on position attributes + # This is a placeholder for demonstration + + # Update daily P&L + current_date = datetime.now().date() + if current_date > self.daily_reset_time: + self.daily_pnl = Decimal("0") + self.daily_reset_time = current_date + print(f"Daily P&L reset for {current_date}") + + self.daily_pnl += total_pnl + + except Exception as e: + print(f"Could not update account info: {e}") + + async def calculate_position_size( + self, entry_price: float, stop_loss: float + ) -> int: + """Calculate optimal position size based on risk parameters.""" + # Calculate risk amount + risk_amount = self.account_balance * self.max_risk_per_trade + + # Calculate risk per contract (MNQ = $20 per point) + price_diff = abs(Decimal(str(entry_price)) - Decimal(str(stop_loss))) + risk_per_contract = price_diff * 20 + + if risk_per_contract <= 0: + return 0 + + # Calculate position size + calculated_size = int(risk_amount / risk_per_contract) + + # Apply position limits + max_size = 10 # Hard limit + return max(1, min(calculated_size, max_size)) + + async def check_risk_limits(self, proposed_size: int) -> tuple[bool, list[str]]: + """Check if proposed trade violates risk limits.""" + errors = [] + + # Check maximum positions + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + + if len(active_positions) >= self.max_positions: + errors.append(f"Maximum positions reached ({self.max_positions})") + + # Check daily risk limit + if abs(self.daily_pnl) >= (self.account_balance * self.max_daily_risk): + errors.append(f"Daily risk limit reached ({self.max_daily_risk * 100}%)") + + # Check portfolio risk + total_position_size = sum(abs(p.size) for p in active_positions) + if total_position_size + proposed_size > 20: # Max 20 contracts total + errors.append("Portfolio size limit reached (20 contracts max)") + + return len(errors) == 0, errors + + async def place_risk_managed_trade( + self, direction: str, stop_offset: float = 50, target_offset: float = 100 + ): + """Place a trade with full risk management.""" + try: + # Get current price + current_price = await self.suite["MNQ"].data.get_current_price() + if not current_price: + print("Could not get current price") + return None + + # Calculate entry, stop, and target prices + if direction == "long": + entry_price = current_price + stop_loss = current_price - stop_offset + take_profit = current_price + target_offset + side = 0 # Buy + else: + entry_price = current_price + stop_loss = current_price + stop_offset + take_profit = current_price - target_offset + side = 1 # Sell + + # Calculate position size + position_size = await self.calculate_position_size(entry_price, stop_loss) + + if position_size == 0: + print("Position size calculated as 0 - trade rejected") + return None + + # Check risk limits + risk_ok, risk_errors = await self.check_risk_limits(position_size) + + if not risk_ok: + print("Trade rejected due to risk limits:") + for error in risk_errors: + print(f" - {error}") + return None + + # Calculate trade risk + risk_per_contract = abs(entry_price - stop_loss) * 20 # MNQ multiplier + total_risk = risk_per_contract * position_size + risk_pct = float((total_risk / float(self.account_balance)) * 100) + + # Display trade details + print("\n" + "=" * 50) + print("RISK-MANAGED TRADE SETUP") + print("=" * 50) + print(f"Direction: {direction.upper()}") + print(f"Current Price: ${entry_price:.2f}") + print(f"Stop Loss: ${stop_loss:.2f} ({stop_offset} points)") + print(f"Take Profit: ${take_profit:.2f} ({target_offset} points)") + print(f"Position Size: {position_size} contracts") + print(f"Risk Amount: ${total_risk:.2f} ({risk_pct:.2f}% of account)") + print(f"R:R Ratio: {target_offset / stop_offset:.1f}:1") + + # Show current status + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + print("\nCurrent Status:") + print(f" Active Positions: {len(active_positions)}") + print(f" Daily P&L: ${self.daily_pnl:.2f}") + print("=" * 50) + + # Confirm trade + response = input(f"\nProceed with {direction.upper()} trade? (y/N): ") + if not response.lower().startswith("y"): + print("Trade cancelled") + return None + + # Place bracket order + print("\nPlacing bracket order...") + + # Get the instrument contract ID + instrument = self.suite["MNQ"].instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + result = await self.suite["MNQ"].orders.place_bracket_order( + contract_id=contract_id, + side=side, + size=position_size, + entry_price=None, # Market entry + entry_type="market", + stop_loss_price=stop_loss, + take_profit_price=take_profit, + ) + + if result and result.success: + # Track the trade + trade_record = { + "direction": direction, + "entry_price": entry_price, + "stop_loss": stop_loss, + "take_profit": take_profit, + "size": position_size, + "risk_amount": total_risk, + "bracket_result": result, + "timestamp": datetime.now(), + "status": "active", + } + self.active_trades.append(trade_record) + + print("\nāœ… Risk-managed trade placed successfully!") + print(f" Entry Order: {result.entry_order_id}") + print(f" Stop Order: {result.stop_order_id}") + print(f" Target Order: {result.target_order_id}") + else: + error_msg = result.error_message if result else "Unknown error" + print(f"\nāŒ Failed to place trade: {error_msg}") + + return result + + except Exception as e: + print(f"Failed to place risk-managed trade: {e}") + import traceback + + traceback.print_exc() + return None + + async def generate_risk_report(self): + """Generate comprehensive risk report.""" + await self.update_account_info() + + print("\n" + "=" * 60) + print("RISK MANAGEMENT REPORT") + print("=" * 60) + + print(f"Account Balance: ${self.account_balance:,.2f}") + print( + f"Daily P&L: ${self.daily_pnl:.2f} ({(self.daily_pnl / self.account_balance) * 100:.2f}%)" + ) + + # Get current positions + positions = await self.suite["MNQ"].positions.get_all_positions() + active_positions = [p for p in positions if p.size != 0] + + print(f"\nActive Positions: {len(active_positions)}") + for i, pos in enumerate(active_positions, 1): + side = "LONG" if pos.size > 0 else "SHORT" + print(f" {i}. {side} {abs(pos.size)} contracts") + + print("\nRisk Limits:") + print( + f" Per Trade: {self.max_risk_per_trade * 100:.1f}% (${self.account_balance * self.max_risk_per_trade:.2f})" + ) + print( + f" Daily: {self.max_daily_risk * 100:.1f}% (${self.account_balance * self.max_daily_risk:.2f})" + ) + print(f" Portfolio: {self.max_portfolio_risk * 100:.1f}%") + print(f" Max Positions: {self.max_positions}") + + if self.active_trades: + print("\nRecent Trades:") + for i, trade in enumerate(self.active_trades[-5:], 1): # Show last 5 + print( + f" {i}. {trade['direction'].upper()} - " + f"${trade['entry_price']:.2f} " + f"(Risk: ${trade['risk_amount']:.2f}) - " + f"{trade['status'].upper()}" + ) + + print("=" * 60) + + +async def main(): + """Main function to run the risk management system.""" + print("Initializing Advanced Risk Management System...") + + # Create TradingSuite with risk management features + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["1min", "5min"], + initial_days=1, + features=["risk_manager"], # Enable risk manager feature + ) + + # Create risk manager + risk_manager = AdvancedRiskManager(suite) + mnq_context = suite["MNQ"] + + # Set up event handlers + async def on_new_bar(_event: Event): + """Handle new bar events to update P&L.""" + # Update account info on each new bar + await risk_manager.update_account_info() + + async def on_quote(_event: Event): + """Handle quote updates.""" + # Could use this for real-time P&L updates + # Placeholder for future real-time updates + + # Register event handlers + await mnq_context.on(EventType.NEW_BAR, on_new_bar) + await mnq_context.on(EventType.QUOTE_UPDATE, on_quote) + + print("\n" + "=" * 60) + print("ADVANCED RISK MANAGEMENT SYSTEM ACTIVE") + print("=" * 60) + print("\nCommands:") + print(" 'long' - Place risk-managed LONG trade") + print(" 'short' - Place risk-managed SHORT trade") + print(" 'report' - Generate risk report") + print(" 'quit' - Exit system") + print("=" * 60) + + try: + while True: + # Get user input + command = input("\nEnter command: ").strip().lower() + + if command == "quit": + break + elif command == "report": + await risk_manager.generate_risk_report() + elif command == "long": + await risk_manager.place_risk_managed_trade("long") + elif command == "short": + await risk_manager.place_risk_managed_trade("short") + elif command: + print(f"Unknown command: {command}") + + # Brief pause to allow async operations + await asyncio.sleep(0.1) + + except KeyboardInterrupt: + print("\n\nShutting down risk management system...") + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("System disconnected. Goodbye!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/advanced_trading/04_orderbook_analysis_and_scalping_strategy.py b/examples/advanced_trading/04_orderbook_analysis_and_scalping_strategy.py new file mode 100644 index 0000000..8bec50b --- /dev/null +++ b/examples/advanced_trading/04_orderbook_analysis_and_scalping_strategy.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python +""" +Advanced order book analysis and scalping strategy. + +This example demonstrates: +- Level 2 order book analysis and imbalance detection +- Tape reading for momentum confirmation +- Market microstructure analysis for scalping +- Iceberg order detection +- Volume profile analysis +- Tight risk management for scalping +""" + +import asyncio +from collections import deque +from typing import Any, Optional + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event +from project_x_py.models import BracketOrderResponse + + +class OrderBookScalpingStrategy: + """Scalping strategy based on order book analysis.""" + + def __init__(self, suite: TradingSuite): + self.suite = suite + self.orderbook = None + self.tick_history = deque(maxlen=100) + self.imbalance_threshold = 0.70 # 70% imbalance threshold + self.min_size_edge = 50 # Minimum size difference for edge + self.active_orders: list[dict[str, Any]] = [] + self.scalp_profit_ticks = 2 # Target 2 ticks profit + self.max_positions = 2 # Max concurrent scalps + + async def initialize_orderbook(self) -> bool: + """Initialize order book for analysis.""" + try: + # Access orderbook through InstrumentContext + mnq_context = self.suite["MNQ"] + if hasattr(mnq_context, "orderbook") and mnq_context.orderbook: + self.orderbook = mnq_context.orderbook + print("āœ… Order book initialized successfully") + return True + else: + print( + "āŒ Order book not available - ensure suite was created with 'orderbook' feature" + ) + return False + except Exception as e: + print(f"Failed to initialize order book: {e}") + return False + + async def analyze_order_book_imbalance(self) -> Optional[dict]: + """Analyze order book for size imbalances.""" + if not self.orderbook: + return None + + try: + # Get market imbalance using the correct method + imbalance_data = await self.orderbook.get_market_imbalance(levels=5) + + if imbalance_data: + # LiquidityAnalysisResponse is a TypedDict, access with brackets + if "imbalance_ratio" in imbalance_data: + ratio = float(imbalance_data["imbalance_ratio"]) + + if abs(ratio) >= self.imbalance_threshold: + return { + "direction": "bullish" if ratio > 0 else "bearish", + "strength": abs(ratio), + "bid_liquidity": imbalance_data.get("bid_liquidity", 0), + "ask_liquidity": imbalance_data.get("ask_liquidity", 0), + "spread": imbalance_data.get("spread", 0), + "levels": 5, + } + + # Fallback to orderbook snapshot + snapshot = await self.orderbook.get_orderbook_snapshot(levels=5) + + if snapshot: + # OrderbookSnapshot is a TypedDict, access with brackets + bids = snapshot.get("bids", []) + asks = snapshot.get("asks", []) + + # Calculate imbalance from snapshot + bid_sizes = sum(level.get("size", 0) for level in bids) if bids else 0 + ask_sizes = sum(level.get("size", 0) for level in asks) if asks else 0 + + if bid_sizes + ask_sizes > 0: + bid_ratio = bid_sizes / (bid_sizes + ask_sizes) + + if bid_ratio >= self.imbalance_threshold: + return { + "direction": "bullish", + "strength": bid_ratio, + "bid_size": bid_sizes, + "ask_size": ask_sizes, + "spread": float(snapshot.get("spread") or 0), + "levels": 5, + } + elif bid_ratio <= (1 - self.imbalance_threshold): + return { + "direction": "bearish", + "strength": 1 - bid_ratio, + "bid_size": bid_sizes, + "ask_size": ask_sizes, + "spread": float(snapshot.get("spread") or 0), + "levels": 5, + } + + return None + + except Exception as e: + print(f"Error analyzing order book: {e}") + return None + + async def check_for_iceberg_orders(self) -> Optional[dict]: + """Detect potential iceberg orders in the book.""" + if not self.orderbook: + return None + + try: + # Use orderbook's iceberg detection with correct parameters + iceberg_info = await self.orderbook.detect_iceberg_orders( + min_refreshes=3, volume_threshold=100, time_window_minutes=5 + ) + + if iceberg_info and iceberg_info.get("detected"): + detections = iceberg_info.get("detections", []) + if detections: + # Get the most confident detection + best_detection = max( + detections, key=lambda x: x.get("confidence", 0) + ) + return { + "detected": True, + "side": best_detection.get("side", "unknown"), + "price_level": best_detection.get("price", 0), + "confidence": best_detection.get("confidence", 0), + "refill_count": best_detection.get("refill_count", 0), + } + + return None + + except Exception as e: + print(f"Error detecting iceberg orders: {e}") + return None + + async def analyze_volume_profile(self) -> Optional[dict]: + """Analyze volume profile for key levels.""" + if not self.orderbook: + return None + + try: + # Get volume profile with correct parameters + profile = await self.orderbook.get_volume_profile( + time_window_minutes=60, price_bins=10 + ) + + if profile: + # Check if profile has the expected structure + poc = profile.get("poc") + value_area = profile.get("value_area") + + if poc: + return { + "poc": float(poc.get("price", 0)), + "poc_volume": poc.get("volume", 0), + "value_area_high": float(value_area.get("high", 0)) + if value_area + else 0, + "value_area_low": float(value_area.get("low", 0)) + if value_area + else 0, + "total_volume": profile.get("total_volume", 0), + } + + # If no profile data, return None + return None + + except Exception as e: + print(f"Error analyzing volume profile: {e}") + return None + + async def analyze_tape_reading(self) -> Optional[dict]: + """Analyze recent trades for momentum.""" + if len(self.tick_history) < 10: + return None + + recent_ticks = list(self.tick_history)[-10:] + + # Analyze trade aggressiveness + buy_volume = sum( + tick["size"] for tick in recent_ticks if tick.get("aggressor") == "buy" + ) + sell_volume = sum( + tick["size"] for tick in recent_ticks if tick.get("aggressor") == "sell" + ) + + total_volume = buy_volume + sell_volume + if total_volume == 0: + return None + + buy_ratio = buy_volume / total_volume + + # Strong buying/selling pressure + if buy_ratio >= 0.70: + return { + "direction": "bullish", + "strength": buy_ratio, + "volume": total_volume, + "buy_volume": buy_volume, + "sell_volume": sell_volume, + } + elif buy_ratio <= 0.30: + return { + "direction": "bearish", + "strength": 1 - buy_ratio, + "volume": total_volume, + "buy_volume": buy_volume, + "sell_volume": sell_volume, + } + + return None + + async def place_scalp_order( + self, direction: str, analysis_data: dict + ) -> Optional[BracketOrderResponse]: + """Place a scalping order with tight stops.""" + try: + mnq_context = self.suite["MNQ"] + + # Get current price + current_price = await mnq_context.data.get_current_price() + if not current_price: + print("Could not get current price") + return None + + tick_size = 0.25 # MNQ tick size + + if direction == "long": + entry_price = float(current_price) + stop_loss = entry_price - (tick_size * 3) # 3 tick stop + take_profit = entry_price + (tick_size * self.scalp_profit_ticks) + side = 0 + else: + entry_price = float(current_price) + stop_loss = entry_price + (tick_size * 3) # 3 tick stop + take_profit = entry_price - (tick_size * self.scalp_profit_ticks) + side = 1 + + # Display scalp setup + print("\n" + "=" * 60) + print(f"SCALP SETUP ({direction.upper()})") + print("=" * 60) + print(f"Entry: ${entry_price:.2f}") + print( + f"Stop: ${stop_loss:.2f} ({abs(entry_price - stop_loss) / tick_size:.0f} ticks)" + ) + print( + f"Target: ${take_profit:.2f} ({abs(take_profit - entry_price) / tick_size:.0f} ticks)" + ) + + # Display analysis details + if "orderbook" in analysis_data: + ob = analysis_data["orderbook"] + print("\nOrder Book Analysis:") + print(f" Imbalance: {ob['strength']:.2%} {ob['direction']}") + if "bid_size" in ob: + print(f" Bid Size: {ob['bid_size']}, Ask Size: {ob['ask_size']}") + elif "bid_liquidity" in ob: + print( + f" Bid Liquidity: {ob['bid_liquidity']}, Ask Liquidity: {ob['ask_liquidity']}" + ) + print(f" Spread: ${ob['spread']:.2f}") + + if "tape" in analysis_data: + tape = analysis_data["tape"] + print("\nTape Reading:") + print(f" Momentum: {tape['strength']:.2%} {tape['direction']}") + print(f" Volume: Buy={tape['buy_volume']}, Sell={tape['sell_volume']}") + + if "iceberg" in analysis_data and analysis_data["iceberg"]: + ice = analysis_data["iceberg"] + print( + f"\nāš ļø Iceberg Detected: {ice['side']} @ ${ice['price_level']:.2f}" + ) + + print("=" * 60) + + # Quick confirmation for scalping + response = input(f"\nExecute {direction.upper()} scalp? (y/N): ") + if not response.lower().startswith("y"): + return None + + # Get instrument contract ID + instrument = mnq_context.instrument_info + contract_id = instrument.id if hasattr(instrument, "id") else "MNQ" + + print("\nPlacing scalp order...") + + # Place bracket order with tight parameters + # Prices will be automatically aligned to tick size + result = await mnq_context.orders.place_bracket_order( + contract_id=contract_id, + side=side, + size=1, # Small size for scalping + entry_price=None, # Market order for quick fills + entry_type="market", + stop_loss_price=stop_loss, + take_profit_price=take_profit, + ) + + if result and result.success: + scalp_record = { + "direction": direction, + "entry_price": entry_price, + "stop_loss": stop_loss, + "take_profit": take_profit, + "bracket": result, + "analysis": analysis_data, + "timestamp": asyncio.get_event_loop().time(), + } + + self.active_orders.append(scalp_record) + + print("āœ… Scalp order placed successfully!") + print(f" Entry Order: {result.entry_order_id}") + print(f" Stop Order: {result.stop_order_id}") + print(f" Target Order: {result.target_order_id}") + + return result + else: + error_msg = result.error_message if result else "Unknown error" + print(f"āŒ Failed to place scalp order: {error_msg}") + return None + + except Exception as e: + print(f"Failed to place scalp order: {e}") + import traceback + + traceback.print_exc() + return None + + async def monitor_scalps(self): + """Monitor active scalping positions.""" + mnq_context = self.suite["MNQ"] + + for scalp in self.active_orders[:]: + try: + elapsed_time = asyncio.get_event_loop().time() - scalp["timestamp"] + + # Time-based cancellation (scalps should be quick) + if elapsed_time > 300: # 5 minutes + print("\nCancelling stale scalp order (>5 min old)") + + # Cancel stop and target orders + bracket: BracketOrderResponse = scalp["bracket"] + if bracket.stop_order_id: + await mnq_context.orders.cancel_order(bracket.stop_order_id) + if bracket.target_order_id: + await mnq_context.orders.cancel_order(bracket.target_order_id) + + self.active_orders.remove(scalp) + + except Exception as e: + print(f"Error monitoring scalp: {e}") + + +async def main(): + """Main function to run the scalping strategy.""" + print("Initializing Order Book Scalping Strategy...") + + # Create suite with order book feature + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["15sec", "1min"], + features=["orderbook"], # Essential for order book analysis + initial_days=1, + ) + + mnq_context = suite["MNQ"] + strategy = OrderBookScalpingStrategy(suite) + + # Initialize order book + if not await strategy.initialize_orderbook(): + print("Cannot proceed without order book data") + await suite.disconnect() + return + + # Track tick count for analysis frequency + tick_count = 0 + + # Event handlers + async def on_tick(event: Event): + """Handle tick updates.""" + nonlocal tick_count + tick_data = event.data + + # Store tick for analysis + strategy.tick_history.append( + { + "price": tick_data.get("price", 0), + "size": tick_data.get("size", 0), + "aggressor": tick_data.get("aggressor", "unknown"), + "timestamp": asyncio.get_event_loop().time(), + } + ) + + tick_count += 1 + + # Analyze every 10th tick to avoid over-trading + if ( + tick_count % 10 == 0 + and len(strategy.active_orders) < strategy.max_positions + ): + # Check for order book imbalances + ob_analysis = await strategy.analyze_order_book_imbalance() + tape_analysis = await strategy.analyze_tape_reading() + iceberg_check = await strategy.check_for_iceberg_orders() + + # Look for confluence between order book and tape + if ob_analysis and tape_analysis: + if ob_analysis["direction"] == tape_analysis["direction"]: + print("\nšŸŽÆ Scalping signal detected!") + print( + f" Order Book: {ob_analysis['direction']} " + f"({ob_analysis['strength']:.2%})" + ) + print( + f" Tape: {tape_analysis['direction']} " + f"({tape_analysis['strength']:.2%})" + ) + + if iceberg_check: + print( + f" āš ļø Iceberg: {iceberg_check['side']} " + f"@ ${iceberg_check['price_level']:.2f}" + ) + + await strategy.place_scalp_order( + ob_analysis["direction"], + { + "orderbook": ob_analysis, + "tape": tape_analysis, + "iceberg": iceberg_check, + }, + ) + + async def on_order_filled(event: Event): + """Handle order fill events.""" + order_data = event.data + print( + f"\nāœ… SCALP FILL: Order {order_data.get('order_id')} " + f"filled at ${order_data.get('fill_price', 0):.2f}" + ) + + # Update active orders + for scalp in strategy.active_orders[:]: + bracket: BracketOrderResponse = scalp["bracket"] + if bracket.entry_order_id == order_data.get("order_id"): + print(f" Entry filled for {scalp['direction']} scalp") + break + + async def on_orderbook_update(_event: Event): + """Handle order book updates.""" + # Could use this for more real-time analysis + + # Register events + await mnq_context.on(EventType.QUOTE_UPDATE, on_tick) + await mnq_context.on(EventType.ORDER_FILLED, on_order_filled) + await mnq_context.on(EventType.ORDERBOOK_UPDATE, on_orderbook_update) + + print("\n" + "=" * 60) + print("ORDER BOOK SCALPING STRATEGY ACTIVE") + print("=" * 60) + print("Strategy Settings:") + print(f" Imbalance Threshold: {strategy.imbalance_threshold:.0%}") + print(f" Profit Target: {strategy.scalp_profit_ticks} ticks") + print(" Stop Loss: 3 ticks") + print(f" Max Positions: {strategy.max_positions}") + print("\nAnalyzing market microstructure for scalping opportunities...") + print("Press Ctrl+C to exit") + print("=" * 60) + + try: + while True: + await asyncio.sleep(10) # Status update every 10 seconds + + # Monitor active scalps + await strategy.monitor_scalps() + + # Display status + current_price = await mnq_context.data.get_current_price() + if current_price: + active_scalps = len(strategy.active_orders) + recent_ticks = len(strategy.tick_history) + + # Get volume profile if available + volume_profile = await strategy.analyze_volume_profile() + + print("\nStatus Update:") + print(f" Price: ${current_price:.2f}") + print(f" Active Scalps: {active_scalps}/{strategy.max_positions}") + print(f" Tick Buffer: {recent_ticks}/100") + + if volume_profile: + print(f" POC: ${volume_profile['poc']:.2f}") + print( + f" Value Area: ${volume_profile['value_area_low']:.2f} - " + f"${volume_profile['value_area_high']:.2f}" + ) + + except KeyboardInterrupt: + print("\n\nShutting down scalping strategy...") + + # Cancel any active orders + for scalp in strategy.active_orders: + try: + bracket: BracketOrderResponse = scalp["bracket"] + if bracket.stop_order_id: + await mnq_context.orders.cancel_order(bracket.stop_order_id) + print(f"Cancelled stop order {bracket.stop_order_id}") + if bracket.target_order_id: + await mnq_context.orders.cancel_order(bracket.target_order_id) + print(f"Cancelled target order {bracket.target_order_id}") + except Exception as e: + print(f"Error cancelling order: {e}") + + finally: + # Disconnect from real-time feeds + await suite.disconnect() + print("Strategy disconnected. Goodbye!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/realtime_data/01_basic_realtime_data_streaming.py b/examples/realtime_data/01_basic_realtime_data_streaming.py new file mode 100644 index 0000000..0c8625a --- /dev/null +++ b/examples/realtime_data/01_basic_realtime_data_streaming.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +""" +Basic real-time data streaming example. + +This example demonstrates: +- Connecting to real-time data feeds +- Handling tick (quote) updates +- Processing new bar events +- Monitoring connection health +- Displaying streaming statistics +""" + +import asyncio +from datetime import datetime + +from project_x_py import EventType, TradingSuite +from project_x_py.event_bus import Event + + +async def main(): + """Main function to run real-time data streaming.""" + # Create suite with real-time capabilities + suite = await TradingSuite.create( + ["MNQ"], + timeframes=["15sec", "1min"], + initial_days=1, # Minimal historical data + ) + mnq_context = suite["MNQ"] + + print(f"Real-time streaming started for {mnq_context.symbol}") + print(f"Connected: {suite.is_connected}") + + # Track statistics + tick_count = 0 + bar_count = 0 + last_price = None + + async def on_tick(event: Event): + """Handle tick updates.""" + nonlocal tick_count, last_price + tick_data = event.data + + tick_count += 1 + last_price = tick_data.get("last") or last_price + + # Display every 10th tick to avoid spam + if tick_count % 10 == 0: + timestamp = datetime.now().strftime("%H:%M:%S") + print(f"[{timestamp}] Tick #{tick_count}: ${last_price:.2f}") + + async def on_new_bar(event: Event): + """Handle new bar events.""" + nonlocal bar_count + bar_count += 1 + + timestamp = datetime.now().strftime("%H:%M:%S") + + # The event.data contains timeframe and nested data + event_data = event.data + timeframe = event_data.get("timeframe", "unknown") + + # Get the bar data directly from the event + bar_data = event_data.get("data", {}) + + if bar_data: + print(f"[{timestamp}] New {timeframe} bar #{bar_count}:") + + # Access the bar data fields directly + open_price = bar_data.get("open", 0) + high_price = bar_data.get("high", 0) + low_price = bar_data.get("low", 0) + close_price = bar_data.get("close", 0) + volume = bar_data.get("volume", 0) + bar_timestamp = bar_data.get("timestamp", "") + + print( + f" OHLC: ${open_price:.2f} / ${high_price:.2f} / " + f"${low_price:.2f} / ${close_price:.2f}" + ) + print(f" Volume: {volume}") + print(f" Timestamp: {bar_timestamp}") + + async def on_connection_status(event: Event): + """Handle connection status changes.""" + status = event.data.get("connected", False) + print(f"Connection Status Changed: {status}") + if status: + print("āœ… Real-time feed connected") + else: + print("āŒ Real-time feed disconnected") + + # Register event handlers + await mnq_context.on(EventType.QUOTE_UPDATE, on_tick) + await mnq_context.on(EventType.NEW_BAR, on_new_bar) + await mnq_context.on(EventType.CONNECTED, on_connection_status) + await mnq_context.on(EventType.DISCONNECTED, on_connection_status) + + print("Listening for real-time data... Press Ctrl+C to exit") + + try: + while True: + await asyncio.sleep(10) + + # Display periodic status + current_price = await mnq_context.data.get_current_price() + connection_health = await mnq_context.data.get_health_score() + + print( + f"Status - Price: ${current_price:.2f} | " + f"Ticks: {tick_count} | Bars: {bar_count} | " + f"Health: {connection_health}" + ) + + except KeyboardInterrupt: + print("\nShutting down real-time stream...") + finally: + # Ensure proper cleanup + await suite.disconnect() + print("Disconnected from real-time feeds") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/realtime_data/02_multi-timeframe_data_sync.py b/examples/realtime_data/02_multi-timeframe_data_sync.py new file mode 100644 index 0000000..abdbff4 --- /dev/null +++ b/examples/realtime_data/02_multi-timeframe_data_sync.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +""" +Multi-timeframe real-time data synchronization +""" + +import asyncio +from collections import defaultdict +from datetime import datetime + +from project_x_py import EventType, TradingSuite +from project_x_py.indicators import RSI, SMA + + +class MultiTimeframeDataProcessor: + def __init__(self, suite: TradingSuite): + self.suite = suite + self.timeframes = ["1min", "5min", "15min"] + self.data_cache = defaultdict(list) + self.last_analysis = defaultdict(dict) + self.analysis_count = 0 + + async def process_new_bar(self, event): + """Process incoming bar data for all timeframes.""" + bar_data = event.data.get("data", event.data) + timeframe = event.data.get("timeframe", "unknown") + + if timeframe not in self.timeframes: + return + + # Store the bar + self.data_cache[timeframe].append(bar_data) + + # Keep only recent bars (memory management) + if len(self.data_cache[timeframe]) > 200: + self.data_cache[timeframe] = self.data_cache[timeframe][-100:] + + print( + f"New {timeframe} bar: ${bar_data['close']:.2f} @ {bar_data.get('timestamp')}" + ) + + # Perform analysis on this timeframe + await self.analyze_timeframe(timeframe) + + # Check for multi-timeframe confluence + if timeframe == "1min": # Trigger confluence check on fastest timeframe + await self.check_confluence() + + async def analyze_timeframe(self, timeframe: str): + """Analyze a specific timeframe with technical indicators.""" + try: + # Get fresh data from suite + bars = await self.suite["MNQ"].data.get_data(timeframe) + + if bars is None: + return + + if len(bars) < 50: # Need enough data for indicators + return + + # Calculate indicators + bars = bars.pipe(SMA, period=20).pipe(RSI, period=14) + + current_price = bars["close"][-1] + current_sma = bars["sma_20"][-1] + current_rsi = bars["rsi_14"][-1] + + # Determine trend and momentum + trend = "bullish" if current_price > current_sma else "bearish" + momentum = ( + "strong" + if (trend == "bullish" and current_rsi > 50) + or (trend == "bearish" and current_rsi < 50) + else "weak" + ) + + # Store analysis + self.last_analysis[timeframe] = { + "price": current_price, + "sma_20": current_sma, + "rsi": current_rsi, + "trend": trend, + "momentum": momentum, + "timestamp": datetime.now(), + } + + print( + f" {timeframe} Analysis - Trend: {trend}, RSI: {current_rsi:.1f}, Momentum: {momentum}" + ) + + except Exception as e: + print(f"Error analyzing {timeframe}: {e}") + + async def check_confluence(self): + """Check for confluence across all timeframes.""" + self.analysis_count += 1 + + # Only check confluence every 5th analysis to avoid spam + if self.analysis_count % 5 != 0: + return + + if len(self.last_analysis) < len(self.timeframes): + return + + # Count bullish/bearish signals + bullish_count = sum( + 1 + for analysis in self.last_analysis.values() + if analysis.get("trend") == "bullish" + ) + bearish_count = sum( + 1 + for analysis in self.last_analysis.values() + if analysis.get("trend") == "bearish" + ) + + # Check for strong confluence + total_timeframes = len(self.last_analysis) + + if bullish_count >= total_timeframes * 0.8: # 80% agreement + print( + f"\n= BULLISH CONFLUENCE DETECTED ({bullish_count}/{total_timeframes})" + ) + await self.display_confluence_analysis("BULLISH") + elif bearish_count >= total_timeframes * 0.8: + print( + f"\n=4 BEARISH CONFLUENCE DETECTED ({bearish_count}/{total_timeframes})" + ) + await self.display_confluence_analysis("BEARISH") + + async def display_confluence_analysis(self, signal_type: str): + """Display detailed confluence analysis.""" + print(f"{signal_type} CONFLUENCE ANALYSIS:") + print("-" * 40) + + for tf, analysis in self.last_analysis.items(): + trend_emoji = "=" if analysis["trend"] == "bullish" else "=" + momentum_emoji = "=" if analysis["momentum"] == "strong" else "=" + + print( + f" {tf:>5} {trend_emoji} {analysis['trend']:>8} | RSI: {analysis['rsi']:>5.1f} | {momentum_emoji} {analysis['momentum']}" + ) + + print("-" * 40) + + # Get current market data + current_price = await self.suite["MNQ"].data.get_current_price() + print(f"Current Price: ${current_price:.2f}") + print() + + +async def main(): + # Create suite with multiple timeframes + suite = await TradingSuite.create( + "MNQ", + timeframes=["1min", "5min", "15min"], + initial_days=3, # Enough data for indicators + ) + + processor = MultiTimeframeDataProcessor(suite) + + # Register event handler + await suite.on(EventType.NEW_BAR, processor.process_new_bar) + + print("Multi-Timeframe Data Processor Active") + print("Monitoring 1min, 5min, and 15min timeframes...") + print("Press Ctrl+C to exit") + + try: + while True: + await asyncio.sleep(15) + + # Display periodic status + print(f"\nStatus Update - {datetime.now().strftime('%H:%M:%S')}") + for tf in processor.timeframes: + cached_bars = len(processor.data_cache[tf]) + analysis = processor.last_analysis.get(tf, {}) + trend = analysis.get("trend", "unknown") + rsi = analysis.get("rsi", 0) + print( + f" {tf}: {cached_bars} bars cached, {trend} trend, RSI: {rsi:.1f}" + ) + + except KeyboardInterrupt: + print("\nShutting down multi-timeframe processor...") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/realtime_data/03_realtime_data_export_visualization.py b/examples/realtime_data/03_realtime_data_export_visualization.py new file mode 100644 index 0000000..da01f32 --- /dev/null +++ b/examples/realtime_data/03_realtime_data_export_visualization.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +""" +Real-time data export with CSV logging and Plotly visualization +""" + +import asyncio +import csv +import json +from datetime import datetime, timedelta +from pathlib import Path + +from project_x_py import EventType, TradingSuite + + +class RealTimeDataExporter: + def __init__(self, suite: TradingSuite, export_dir: str = "data_exports"): + self.suite = suite + self.export_dir = Path(export_dir) + self.export_dir.mkdir(exist_ok=True) + + # Data storage + self.tick_data = [] + self.bar_data = [] + self.trade_data = [] + + # File handles + self.csv_files = {} + self.export_interval = 60 # Export every 60 seconds + + async def initialize_export_files(self): + """Initialize CSV files for data export.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Bar data CSV + bar_file = self.export_dir / f"bars_{timestamp}.csv" + bar_csv = open(bar_file, "w", newline="") + bar_writer = csv.writer(bar_csv) + bar_writer.writerow( + ["timestamp", "timeframe", "open", "high", "low", "close", "volume"] + ) + self.csv_files["bars"] = {"file": bar_csv, "writer": bar_writer} + + print(f"Export files initialized in {self.export_dir}") + + async def process_bar(self, event): + """Process and export bar data.""" + timestamp = datetime.now().isoformat() + + # Get the real data for the timeframe + # Data from the event is from the new bar that was just started, so we need to get the previous bar + real_data = await self.suite["MNQ"].data.get_data( + event.data.get("timeframe", "unknown") + ) + + if real_data is None: + return + + # Store in memory + bar_record = { + "timestamp": timestamp, + "bar_timestamp": real_data["timestamp"][-2], + "timeframe": event.data.get("timeframe", "unknown"), + "open": real_data["open"][-2], + "high": real_data["high"][-2], + "low": real_data["low"][-2], + "close": real_data["close"][-2], + "volume": real_data["volume"][-2], + } + + self.bar_data.append(bar_record) + + # Write to CSV + if "bars" in self.csv_files: + writer = self.csv_files["bars"]["writer"] + writer.writerow( + [ + bar_record["bar_timestamp"] or timestamp, + bar_record["timeframe"], + bar_record["open"], + bar_record["high"], + bar_record["low"], + bar_record["close"], + bar_record["volume"], + ] + ) + self.csv_files["bars"]["file"].flush() + + print(f"Exported {bar_record['timeframe']} bar: ${bar_record['close']:.2f}") + + async def export_json_snapshot(self): + """Export current data snapshot as JSON.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + snapshot = { + "export_timestamp": datetime.now().isoformat(), + "data_summary": { + "bar_count": len(self.bar_data), + }, + "recent_data": { + "bars": self.bar_data[-5:], # Last 5 bars + }, + } + + json_file = self.export_dir / f"snapshot_{timestamp}.json" + with open(json_file, "w") as f: + json.dump(snapshot, f, indent=2) + + print(f"JSON snapshot exported: {json_file}") + return json_file + + def close_files(self): + """Close all open CSV files.""" + for file_info in self.csv_files.values(): + file_info["file"].close() + print("Export files closed") + + +async def main(): + # Create suite for data export + suite = await TradingSuite.create( + "MNQ", timeframes=["15sec", "1min", "5min"], initial_days=1 + ) + + mnq_context = suite["MNQ"] + + exporter = RealTimeDataExporter(suite) + await exporter.initialize_export_files() + + # Event handlers + await suite.on(EventType.NEW_BAR, exporter.process_bar) + + print("Real-time Data Exporter Active") + print(f"Exporting to: {exporter.export_dir}") + print("Streaming data...") + + try: + export_timer = 0 + + while True: + await asyncio.sleep(10) + export_timer += 10 + + # Periodic status + current_price = await mnq_context.data.get_current_price() + if current_price is None: + continue + + print(f"Price: ${current_price:.2f} | Bars: {len(exporter.bar_data)}") + + # Auto-export JSON snapshot every 5 minutes + if export_timer >= 300: # 5 minutes + await exporter.export_json_snapshot() + export_timer = 0 + + except KeyboardInterrupt: + print("\nShutting down data exporter...") + + # Final exports + print("Creating final exports...") + await exporter.export_json_snapshot() + + # Close files + exporter.close_files() + + print("Data export complete!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index d2f0287..b622d54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "project-x-py" -version = "3.5.5" +version = "3.5.6" description = "High-performance Python SDK for futures trading with real-time WebSocket data, technical indicators, order management, and market depth analysis" readme = "README.md" license = { text = "MIT" } diff --git a/src/project_x_py/__init__.py b/src/project_x_py/__init__.py index d01430a..d831fd0 100644 --- a/src/project_x_py/__init__.py +++ b/src/project_x_py/__init__.py @@ -109,7 +109,7 @@ - `utils`: Utility functions and calculations """ -__version__ = "3.5.5" +__version__ = "3.5.6" __author__ = "TexasCoding" # Core client classes - renamed from Async* to standard names diff --git a/src/project_x_py/event_bus.py b/src/project_x_py/event_bus.py index eb9fd90..41a21f4 100644 --- a/src/project_x_py/event_bus.py +++ b/src/project_x_py/event_bus.py @@ -388,3 +388,20 @@ def get_handler_count(self, event: EventType | str | None = None) -> int: count = len(self._legacy_handlers.get(str(event), [])) count += len(self._wildcard_handlers) return count + + async def forward_to(self, target_bus: "EventBus") -> None: + """ + Forward all events from this bus to the target bus. + + This sets up a wildcard handler that forwards all events to another EventBus, + enabling event propagation from instrument-specific buses to the main suite bus. + + Args: + target_bus: The EventBus to forward events to + """ + + async def forwarder(event: Event) -> None: + """Forward event to target bus.""" + await target_bus.emit(event.type, event.data, event.source) + + await self.on_any(forwarder) diff --git a/src/project_x_py/indicators/__init__.py b/src/project_x_py/indicators/__init__.py index 22a53fb..b3713eb 100644 --- a/src/project_x_py/indicators/__init__.py +++ b/src/project_x_py/indicators/__init__.py @@ -207,7 +207,7 @@ ) # Version info -__version__ = "3.5.5" +__version__ = "3.5.6" __author__ = "TexasCoding" diff --git a/src/project_x_py/order_manager/bracket_orders.py b/src/project_x_py/order_manager/bracket_orders.py index a9b6507..c9d5e43 100644 --- a/src/project_x_py/order_manager/bracket_orders.py +++ b/src/project_x_py/order_manager/bracket_orders.py @@ -276,20 +276,38 @@ async def place_bracket_order( ) try: - # CRITICAL: Validate tick sizes BEFORE any price operations + # CRITICAL: Align prices to tick sizes BEFORE any price operations if hasattr(self, "project_x") and self.project_x: - from .utils import validate_price_tick_size + from .utils import align_price_to_tick_size - if entry_price is not None: # Only validate if not market order - await validate_price_tick_size( - entry_price, contract_id, self.project_x, "entry_price" + # Align all prices to valid tick sizes + if entry_price is not None: # Only align if not market order + aligned_entry = await align_price_to_tick_size( + entry_price, contract_id, self.project_x ) - await validate_price_tick_size( - stop_loss_price, contract_id, self.project_x, "stop_loss_price" + if aligned_entry is not None and aligned_entry != entry_price: + logger.info( + f"Entry price aligned from {entry_price} to {aligned_entry}" + ) + entry_price = aligned_entry + + aligned_stop = await align_price_to_tick_size( + stop_loss_price, contract_id, self.project_x ) - await validate_price_tick_size( - take_profit_price, contract_id, self.project_x, "take_profit_price" + if aligned_stop is not None and aligned_stop != stop_loss_price: + logger.info( + f"Stop loss price aligned from {stop_loss_price} to {aligned_stop}" + ) + stop_loss_price = aligned_stop + + aligned_target = await align_price_to_tick_size( + take_profit_price, contract_id, self.project_x ) + if aligned_target is not None and aligned_target != take_profit_price: + logger.info( + f"Take profit price aligned from {take_profit_price} to {aligned_target}" + ) + take_profit_price = aligned_target # Convert prices to Decimal for precise comparisons # For market orders, use a placeholder for entry_decimal that won't affect validation diff --git a/src/project_x_py/order_manager/core.py b/src/project_x_py/order_manager/core.py index a23136c..73dbabb 100644 --- a/src/project_x_py/order_manager/core.py +++ b/src/project_x_py/order_manager/core.py @@ -86,7 +86,6 @@ async def main(): from .utils import ( align_price_to_tick_size, resolve_contract_id, - validate_price_tick_size, ) if TYPE_CHECKING: @@ -462,45 +461,48 @@ async def place_order( if trail_price is not None and trail_price < 0: raise ProjectXOrderError(f"Invalid negative price: {trail_price}") - # CRITICAL: Validate tick size BEFORE any price operations - await validate_price_tick_size( - limit_price, contract_id, self.project_x, "limit_price" - ) - await validate_price_tick_size( - stop_price, contract_id, self.project_x, "stop_price" - ) - await validate_price_tick_size( - trail_price, contract_id, self.project_x, "trail_price" - ) + # CRITICAL: Align prices to tick size BEFORE any price operations + if limit_price is not None: + aligned_limit = await align_price_to_tick_size( + limit_price, contract_id, self.project_x + ) + if aligned_limit is not None and aligned_limit != limit_price: + self.logger.info( + f"Limit price aligned from {limit_price} to {aligned_limit}" + ) + limit_price = aligned_limit - # Convert prices to Decimal for precision, then align to tick size - decimal_limit_price = ( + if stop_price is not None: + aligned_stop = await align_price_to_tick_size( + stop_price, contract_id, self.project_x + ) + if aligned_stop is not None and aligned_stop != stop_price: + self.logger.info( + f"Stop price aligned from {stop_price} to {aligned_stop}" + ) + stop_price = aligned_stop + + if trail_price is not None: + aligned_trail = await align_price_to_tick_size( + trail_price, contract_id, self.project_x + ) + if aligned_trail is not None and aligned_trail != trail_price: + self.logger.info( + f"Trail price aligned from {trail_price} to {aligned_trail}" + ) + trail_price = aligned_trail + + # Convert prices to Decimal for precision (already aligned above) + aligned_limit_price = ( Decimal(str(limit_price)) if limit_price is not None else None ) - decimal_stop_price = ( + aligned_stop_price = ( Decimal(str(stop_price)) if stop_price is not None else None ) - decimal_trail_price = ( + aligned_trail_price = ( Decimal(str(trail_price)) if trail_price is not None else None ) - # Align all prices to tick size to prevent "Invalid price" errors - aligned_limit_price = await align_price_to_tick_size( - float(decimal_limit_price) if decimal_limit_price is not None else None, - contract_id, - self.project_x, - ) - aligned_stop_price = await align_price_to_tick_size( - float(decimal_stop_price) if decimal_stop_price is not None else None, - contract_id, - self.project_x, - ) - aligned_trail_price = await align_price_to_tick_size( - float(decimal_trail_price) if decimal_trail_price is not None else None, - contract_id, - self.project_x, - ) - # Use account_info if no account_id provided if account_id is None: if not self.project_x.account_info: @@ -1043,13 +1045,26 @@ async def modify_order( contract_id = existing_order.contractId - # CRITICAL: Validate tick size BEFORE any price operations - await validate_price_tick_size( - limit_price, contract_id, self.project_x, "limit_price" - ) - await validate_price_tick_size( - stop_price, contract_id, self.project_x, "stop_price" - ) + # CRITICAL: Align prices to tick size BEFORE any price operations + if limit_price is not None: + aligned_limit = await align_price_to_tick_size( + limit_price, contract_id, self.project_x + ) + if aligned_limit is not None and aligned_limit != limit_price: + self.logger.info( + f"Limit price aligned from {limit_price} to {aligned_limit}" + ) + limit_price = aligned_limit + + if stop_price is not None: + aligned_stop = await align_price_to_tick_size( + stop_price, contract_id, self.project_x + ) + if aligned_stop is not None and aligned_stop != stop_price: + self.logger.info( + f"Stop price aligned from {stop_price} to {aligned_stop}" + ) + stop_price = aligned_stop # Convert prices to Decimal for precision, then align to tick size decimal_limit = ( diff --git a/src/project_x_py/trading_suite.py b/src/project_x_py/trading_suite.py index 2253b5f..f1bf151 100644 --- a/src/project_x_py/trading_suite.py +++ b/src/project_x_py/trading_suite.py @@ -89,6 +89,7 @@ class InstrumentContext: data: Real-time data manager for OHLCV data orders: Order management system positions: Position tracking system + event_bus: Event bus for this instrument's events orderbook: Level 2 market depth (optional) risk_manager: Risk management system (optional) """ @@ -98,9 +99,60 @@ class InstrumentContext: data: RealtimeDataManager orders: OrderManager positions: PositionManager + event_bus: EventBus orderbook: OrderBook | None = None risk_manager: RiskManager | None = None + async def on(self, event: EventType | str, handler: Any) -> None: + """ + Register event handler on this instrument's event bus. + + Args: + event: Event type to listen for + handler: Async callable to handle events + """ + await self.event_bus.on(event, handler) + + async def once(self, event: EventType | str, handler: Any) -> None: + """ + Register one-time event handler on this instrument's event bus. + + Args: + event: Event type to listen for + handler: Async callable to handle event once + """ + await self.event_bus.once(event, handler) + + async def off( + self, event: EventType | str | None = None, handler: Any | None = None + ) -> None: + """ + Remove event handler(s) from this instrument's event bus. + + Args: + event: Event type to remove handler from (None for all) + handler: Specific handler to remove (None for all) + """ + await self.event_bus.off(event, handler) + + async def wait_for( + self, event: EventType | str, timeout: float | None = None + ) -> Any: + """ + Wait for specific event to occur on this instrument's event bus. + + Args: + event: Event type to wait for + timeout: Optional timeout in seconds + + Returns: + Event object when received + + Raises: + TimeoutError: If timeout expires + """ + return await self.event_bus.wait_for(event, timeout) + class Features(str, Enum): """Available feature flags for TradingSuite.""" @@ -489,6 +541,9 @@ async def create( # Create suite instance with contexts suite = cls(client, realtime_client, config, instrument_contexts) + # Set up event forwarding from instrument buses to suite bus + await suite._setup_event_forwarding() + # Store the context for cleanup later suite._client_context = client_context @@ -587,6 +642,7 @@ async def _create_single_context(symbol: str) -> tuple[str, InstrumentContext]: data=data_manager, orders=order_manager, positions=position_manager, + event_bus=event_bus, orderbook=orderbook, risk_manager=risk_manager, ) @@ -640,6 +696,20 @@ async def _create_single_context_with_tracking( await cls._cleanup_contexts(created_contexts.copy()) raise + async def _setup_event_forwarding(self) -> None: + """ + Set up event forwarding from instrument EventBuses to the suite's main EventBus. + + This ensures that events emitted to instrument-specific EventBuses are also + forwarded to the suite-level EventBus, allowing suite-level handlers to receive + events from all instruments. + """ + if not self._instruments: + return + + for context in self._instruments.values(): + await context.event_bus.forward_to(self.events) + @classmethod async def _cleanup_contexts(cls, contexts: dict[str, InstrumentContext]) -> None: """ diff --git a/tests/order_manager/test_core_advanced.py b/tests/order_manager/test_core_advanced.py index 487657d..808188f 100644 --- a/tests/order_manager/test_core_advanced.py +++ b/tests/order_manager/test_core_advanced.py @@ -248,9 +248,11 @@ async def test_place_order_aligns_all_price_types(self, order_manager): # Should have aligned all three prices assert mock_align.call_count == 3 call_args = order_manager.project_x._make_request.call_args[1]["data"] - assert call_args.get("limitPrice") == 100.78 - assert call_args.get("stopPrice") == 99.33 - assert call_args.get("trailPrice") == 2.00 + # Prices are now Decimal objects for precision + from decimal import Decimal + assert call_args.get("limitPrice") == Decimal('100.78') + assert call_args.get("stopPrice") == Decimal('99.33') + assert call_args.get("trailPrice") == Decimal('2.0') class TestConcurrentOrderOperations: diff --git a/tests/test_event_forwarding.py b/tests/test_event_forwarding.py new file mode 100644 index 0000000..7a2ab5a --- /dev/null +++ b/tests/test_event_forwarding.py @@ -0,0 +1,199 @@ +""" +Test event forwarding from instrument EventBuses to suite EventBus. + +This test validates the core event forwarding functionality without +requiring complex mocking of the real-time client. +""" + +import asyncio + +import pytest + +from project_x_py.event_bus import EventBus, EventType +from project_x_py.trading_suite import InstrumentContext + + +@pytest.mark.asyncio +async def test_event_forwarding_from_instrument_to_suite(): + """Test that events are forwarded from instrument EventBus to suite EventBus.""" + # Create suite-level EventBus + suite_event_bus = EventBus() + + # Create instrument-specific EventBus + mnq_event_bus = EventBus() + nq_event_bus = EventBus() + + # Set up event forwarding (simulating what _setup_event_forwarding does) + await mnq_event_bus.forward_to(suite_event_bus) + await nq_event_bus.forward_to(suite_event_bus) + + # Track events received at suite level + suite_events = [] + + async def suite_handler(event): + suite_events.append(event) + + # Register handler at suite level + await suite_event_bus.on(EventType.NEW_BAR, suite_handler) + + # Emit events from instrument-specific buses + await mnq_event_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + await nq_event_bus.emit( + EventType.NEW_BAR, + {"instrument": "NQ", "timeframe": "1min", "close": 200.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Suite should receive events from both instruments + assert len(suite_events) == 2 + assert any(e.data["instrument"] == "MNQ" for e in suite_events) + assert any(e.data["instrument"] == "NQ" for e in suite_events) + + +@pytest.mark.asyncio +async def test_instrument_context_event_methods(): + """Test that InstrumentContext event methods work correctly.""" + # Create a mock InstrumentContext with event_bus + event_bus = EventBus() + + # Create a simple mock context with the event_bus attribute + class MockInstrumentContext: + def __init__(self): + self.event_bus = event_bus + self.symbol = "MNQ" + + async def on(self, event, handler): + """Register event handler on this instrument's event bus.""" + await self.event_bus.on(event, handler) + + async def once(self, event, handler): + """Register one-time event handler on this instrument's event bus.""" + await self.event_bus.once(event, handler) + + async def off(self, event=None, handler=None): + """Remove event handler(s) from this instrument's event bus.""" + await self.event_bus.off(event, handler) + + async def wait_for(self, event, timeout=None): + """Wait for specific event to occur on this instrument's event bus.""" + return await self.event_bus.wait_for(event, timeout) + + context = MockInstrumentContext() + + # Test that methods exist and are callable + events_received = [] + + async def handler(event): + events_received.append(event) + + # Register handler + await context.on(EventType.NEW_BAR, handler) + + # Emit event + await context.event_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Handler should have received the event + assert len(events_received) == 1 + assert events_received[0].data["instrument"] == "MNQ" + + +@pytest.mark.asyncio +async def test_wait_for_with_event_forwarding(): + """Test that wait_for works with event forwarding.""" + # Create suite and instrument EventBuses + suite_event_bus = EventBus() + instrument_event_bus = EventBus() + + # Set up forwarding + await instrument_event_bus.forward_to(suite_event_bus) + + # Create a task that waits for an event at suite level + async def wait_for_event(): + return await suite_event_bus.wait_for(EventType.NEW_BAR, timeout=1.0) + + wait_task = asyncio.create_task(wait_for_event()) + + # Give the task time to start waiting + await asyncio.sleep(0.1) + + # Emit event from instrument bus + await instrument_event_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + + # wait_for should receive the event + try: + event = await asyncio.wait_for(wait_task, timeout=1.0) + assert event.data["instrument"] == "MNQ" + assert event.data["close"] == 100.0 + except asyncio.TimeoutError: + pytest.fail("wait_for did not receive forwarded event") + + +@pytest.mark.asyncio +async def test_multiple_instrument_event_isolation(): + """Test that instrument-specific handlers only receive their own events.""" + # Create EventBuses + suite_event_bus = EventBus() + mnq_event_bus = EventBus() + nq_event_bus = EventBus() + + # Set up forwarding + await mnq_event_bus.forward_to(suite_event_bus) + await nq_event_bus.forward_to(suite_event_bus) + + # Track events for each instrument + mnq_events = [] + nq_events = [] + suite_events = [] + + async def mnq_handler(event): + mnq_events.append(event) + + async def nq_handler(event): + nq_events.append(event) + + async def suite_handler(event): + suite_events.append(event) + + # Register handlers + await mnq_event_bus.on(EventType.NEW_BAR, mnq_handler) + await nq_event_bus.on(EventType.NEW_BAR, nq_handler) + await suite_event_bus.on(EventType.NEW_BAR, suite_handler) + + # Emit events + await mnq_event_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + await nq_event_bus.emit( + EventType.NEW_BAR, + {"instrument": "NQ", "timeframe": "1min", "close": 200.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Each instrument handler should only receive its own events + assert len(mnq_events) == 1 + assert mnq_events[0].data["instrument"] == "MNQ" + + assert len(nq_events) == 1 + assert nq_events[0].data["instrument"] == "NQ" + + # Suite handler should receive all events + assert len(suite_events) == 2 + assert any(e.data["instrument"] == "MNQ" for e in suite_events) + assert any(e.data["instrument"] == "NQ" for e in suite_events) diff --git a/tests/test_event_system_multi_instrument.py b/tests/test_event_system_multi_instrument.py new file mode 100644 index 0000000..b29b127 --- /dev/null +++ b/tests/test_event_system_multi_instrument.py @@ -0,0 +1,344 @@ +""" +Tests for multi-instrument event system functionality. + +These tests define the expected behavior of the event system when using +multiple instruments with TradingSuite. Following TDD principles, these +tests are written to specify how the system SHOULD work. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from project_x_py import TradingSuite +from project_x_py.event_bus import EventType + + +@pytest.mark.asyncio +async def test_suite_receives_events_from_all_instruments(): + """Test that suite-level event handlers receive events from all instruments.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + # Create enough mock instruments for multiple calls + def mock_get_instrument(symbol): + return MagicMock(symbol=symbol, name=symbol, exchange="CME", min_tick=0.25, id=f"id_{symbol}") + mock_client.get_instrument.side_effect = mock_get_instrument + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create(["MNQ", "NQ"], timeframes=["1min"]) + + # Track events received at suite level + events_received = [] + + async def on_new_bar(event): + events_received.append(event) + + # Register handler at suite level + await suite.on(EventType.NEW_BAR, on_new_bar) + + # Emit events from instrument-specific event buses + await suite["MNQ"].event_bus.emit( + EventType.NEW_BAR, {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + await suite["NQ"].event_bus.emit( + EventType.NEW_BAR, {"instrument": "NQ", "timeframe": "1min", "close": 200.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Suite should receive events from both instruments + assert len(events_received) == 2 + assert any(e.data["instrument"] == "MNQ" for e in events_received) + assert any(e.data["instrument"] == "NQ" for e in events_received) + + +@pytest.mark.asyncio +async def test_instrument_context_has_event_methods(): + """Test that InstrumentContext provides wait_for and on methods.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + mock_client.get_instrument.return_value = MagicMock( + symbol="MNQ", name="MNQ", exchange="CME", min_tick=0.25 + ) + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create("MNQ", timeframes=["1min"]) + mnq_context = suite["MNQ"] + + # InstrumentContext should have event methods + assert hasattr(mnq_context, "wait_for") + assert hasattr(mnq_context, "on") + assert hasattr(mnq_context, "off") + + # Test that methods are callable + events_received = [] + + async def handler(event): + events_received.append(event) + + # Should be able to register handler on instrument context + await mnq_context.on(EventType.NEW_BAR, handler) + + # Emit event to instrument's event bus + await mnq_context.event_bus.emit( + EventType.NEW_BAR, {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Handler should have received the event + assert len(events_received) == 1 + assert events_received[0].data["instrument"] == "MNQ" + + +@pytest.mark.asyncio +async def test_wait_for_works_at_suite_level(): + """Test that suite.wait_for() receives events from any instrument.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + # Create enough mock instruments for multiple calls + def mock_get_instrument(symbol): + return MagicMock(symbol=symbol, name=symbol, exchange="CME", min_tick=0.25, id=f"id_{symbol}") + mock_client.get_instrument.side_effect = mock_get_instrument + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create(["MNQ", "NQ"], timeframes=["1min"]) + + # Create a task that waits for an event + async def wait_for_event(): + event = await suite.wait_for(EventType.NEW_BAR) + return event + + wait_task = asyncio.create_task(wait_for_event()) + + # Give the task time to start waiting + await asyncio.sleep(0.1) + + # Emit event from one of the instruments + await suite["NQ"].event_bus.emit( + EventType.NEW_BAR, {"instrument": "NQ", "timeframe": "1min", "close": 200.0} + ) + + # wait_for should receive the event + try: + event = await asyncio.wait_for(wait_task, timeout=1.0) + assert event.data["instrument"] == "NQ" + assert event.data["close"] == 200.0 + except asyncio.TimeoutError: + pytest.fail("suite.wait_for() did not receive event from instrument") + + +@pytest.mark.asyncio +async def test_wait_for_works_at_instrument_level(): + """Test that instrument_context.wait_for() receives events for that instrument.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + mock_client.get_instrument.return_value = MagicMock( + symbol="MNQ", name="MNQ", exchange="CME", min_tick=0.25 + ) + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create("MNQ", timeframes=["1min"]) + mnq_context = suite["MNQ"] + + # Create a task that waits for an event + async def wait_for_event(): + event = await mnq_context.wait_for(EventType.NEW_BAR) + return event + + wait_task = asyncio.create_task(wait_for_event()) + + # Give the task time to start waiting + await asyncio.sleep(0.1) + + # Emit event to instrument's event bus + await mnq_context.event_bus.emit( + EventType.NEW_BAR, {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + + # wait_for should receive the event + try: + event = await asyncio.wait_for(wait_task, timeout=1.0) + assert event.data["instrument"] == "MNQ" + assert event.data["close"] == 100.0 + except asyncio.TimeoutError: + pytest.fail("instrument_context.wait_for() did not receive event") + + +@pytest.mark.asyncio +async def test_event_filtering_by_instrument(): + """Test that instrument contexts can filter events by instrument.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + # Create enough mock instruments for multiple calls + def mock_get_instrument(symbol): + return MagicMock(symbol=symbol, name=symbol, exchange="CME", min_tick=0.25, id=f"id_{symbol}") + mock_client.get_instrument.side_effect = mock_get_instrument + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create(["MNQ", "NQ"], timeframes=["1min"]) + + mnq_events = [] + nq_events = [] + + async def mnq_handler(event): + mnq_events.append(event) + + async def nq_handler(event): + nq_events.append(event) + + # Register handlers on specific instrument contexts + await suite["MNQ"].on(EventType.NEW_BAR, mnq_handler) + await suite["NQ"].on(EventType.NEW_BAR, nq_handler) + + # Emit events from both instruments + await suite["MNQ"].event_bus.emit( + EventType.NEW_BAR, {"instrument": "MNQ", "timeframe": "1min", "close": 100.0} + ) + await suite["NQ"].event_bus.emit( + EventType.NEW_BAR, {"instrument": "NQ", "timeframe": "1min", "close": 200.0} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Each handler should only receive events from its instrument + assert len(mnq_events) == 1 + assert mnq_events[0].data["instrument"] == "MNQ" + + assert len(nq_events) == 1 + assert nq_events[0].data["instrument"] == "NQ" + + +@pytest.mark.asyncio +async def test_suite_level_handler_receives_all_instruments(): + """Test that a single suite-level handler receives events from all instruments.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class, \ + patch("project_x_py.trading_suite.ProjectXRealtimeClient") as mock_realtime_class: + + # Mock main client + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Mock realtime client + mock_realtime = AsyncMock() + mock_realtime_class.return_value = mock_realtime + + # Setup mock responses + mock_client.authenticate.return_value = None + # Create enough mock instruments for multiple calls + def mock_get_instrument(symbol): + return MagicMock(symbol=symbol, name=symbol, exchange="CME", min_tick=0.25, id=f"id_{symbol}") + mock_client.get_instrument.side_effect = mock_get_instrument + mock_client.get_bars.return_value = MagicMock(is_empty=lambda: True) + + # Mock realtime client connection + mock_realtime.connect.return_value = None + mock_realtime.is_connected.return_value = True + + suite = await TradingSuite.create(["MNQ", "NQ", "ES"], timeframes=["1min"]) + + all_events = [] + + async def universal_handler(event): + all_events.append(event) + + # Register a single handler at suite level + await suite.on(EventType.NEW_BAR, universal_handler) + + # Emit events from all instruments + for symbol in ["MNQ", "NQ", "ES"]: + await suite[symbol].event_bus.emit( + EventType.NEW_BAR, + { + "instrument": symbol, + "timeframe": "1min", + "close": 100.0 * (ord(symbol[0]) - 64), + }, + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Handler should receive events from all instruments + assert len(all_events) == 3 + instruments = {e.data["instrument"] for e in all_events} + assert instruments == {"MNQ", "NQ", "ES"} diff --git a/tests/test_trading_suite_events.py b/tests/test_trading_suite_events.py new file mode 100644 index 0000000..6df6c8d --- /dev/null +++ b/tests/test_trading_suite_events.py @@ -0,0 +1,476 @@ +""" +Comprehensive test suite for TradingSuite event system. + +These tests ensure that the multi-instrument event system works correctly +and prevents regressions in event handling functionality. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from project_x_py import TradingSuite +from project_x_py.event_bus import EventBus, EventType + + +class TestTradingSuiteEventSystem: + """Test suite for TradingSuite event system functionality.""" + + @pytest.fixture + async def mock_client(self): + """Create a mock ProjectX client.""" + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + + # Setup mock responses + mock_client.authenticate = AsyncMock(return_value=None) + mock_client.get_instrument = AsyncMock( + return_value=MagicMock( + symbol="MNQ", name="MNQ", exchange="CME", min_tick=0.25 + ) + ) + mock_client.get_bars = AsyncMock( + return_value=MagicMock(is_empty=lambda: True) + ) + mock_client.search_positions = AsyncMock(return_value=[]) + mock_client.search_orders = AsyncMock(return_value=[]) + + yield mock_client + + @pytest.fixture + async def mock_realtime_client(self): + """Create a mock realtime client.""" + with patch("project_x_py.realtime.ProjectXRealtimeClient") as mock_rt_class: + mock_rt = AsyncMock() + mock_rt_class.return_value = mock_rt + + # Mock connection methods + mock_rt.connect = AsyncMock(return_value=True) + mock_rt.disconnect = AsyncMock() + mock_rt.is_connected = True + mock_rt.subscribe_to_market = AsyncMock() + mock_rt.subscribe_to_user = AsyncMock() + + yield mock_rt + + @pytest.mark.asyncio + async def test_instrument_context_has_event_methods(self): + """Test that InstrumentContext has all required event methods.""" + # Create an InstrumentContext directly with event_bus + from project_x_py.trading_suite import InstrumentContext + event_bus = EventBus() + context = InstrumentContext( + symbol="MNQ", + instrument_info=MagicMock(), + data=MagicMock(), + orders=MagicMock(), + positions=MagicMock(), + event_bus=event_bus + ) + + # Verify all event methods exist + assert hasattr(context, "on") + assert hasattr(context, "once") + assert hasattr(context, "off") + assert hasattr(context, "wait_for") + assert hasattr(context, "event_bus") + + # Verify methods are callable + assert callable(context.on) + assert callable(context.once) + assert callable(context.off) + assert callable(context.wait_for) + + # Test that methods work + events_received = [] + + async def handler(event): + events_received.append(event) + + # Register handler + await context.on(EventType.NEW_BAR, handler) + + # Emit event + await context.event_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "data": {"close": 100.0}} + ) + + # Allow propagation + await asyncio.sleep(0.05) + + # Verify handler was called + assert len(events_received) == 1 + assert events_received[0].data["instrument"] == "MNQ" + + @pytest.mark.asyncio + async def test_event_forwarding_between_instruments_and_suite(self): + """Test that events are forwarded from instrument EventBuses to suite EventBus.""" + # Create independent EventBuses + suite_bus = EventBus() + mnq_bus = EventBus() + nq_bus = EventBus() + + # Set up forwarding + await mnq_bus.forward_to(suite_bus) + await nq_bus.forward_to(suite_bus) + + # Track events at different levels + suite_events = [] + mnq_events = [] + nq_events = [] + + async def suite_handler(event): + suite_events.append(event) + + async def mnq_handler(event): + mnq_events.append(event) + + async def nq_handler(event): + nq_events.append(event) + + # Register handlers + await suite_bus.on(EventType.NEW_BAR, suite_handler) + await mnq_bus.on(EventType.NEW_BAR, mnq_handler) + await nq_bus.on(EventType.NEW_BAR, nq_handler) + + # Emit events from each instrument + await mnq_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "data": {"close": 100.0}} + ) + await nq_bus.emit( + EventType.NEW_BAR, + {"instrument": "NQ", "timeframe": "1min", "data": {"close": 200.0}} + ) + + # Allow event propagation + await asyncio.sleep(0.1) + + # Verify instrument handlers only receive their own events + assert len(mnq_events) == 1 + assert mnq_events[0].data["instrument"] == "MNQ" + + assert len(nq_events) == 1 + assert nq_events[0].data["instrument"] == "NQ" + + # Verify suite handler receives all events + assert len(suite_events) == 2 + assert any(e.data["instrument"] == "MNQ" for e in suite_events) + assert any(e.data["instrument"] == "NQ" for e in suite_events) + + @pytest.mark.asyncio + async def test_wait_for_with_forwarding(self): + """Test that wait_for works correctly with event forwarding.""" + suite_bus = EventBus() + instrument_bus = EventBus() + + # Set up forwarding + await instrument_bus.forward_to(suite_bus) + + # Start waiting at suite level + async def wait_for_event(): + return await suite_bus.wait_for(EventType.NEW_BAR, timeout=1.0) + + wait_task = asyncio.create_task(wait_for_event()) + + # Give task time to start waiting + await asyncio.sleep(0.05) + + # Emit from instrument + await instrument_bus.emit( + EventType.NEW_BAR, + {"instrument": "MNQ", "timeframe": "1min", "data": {"close": 100.0}} + ) + + # Should receive the event + event = await wait_task + assert event.data["instrument"] == "MNQ" + assert event.data["data"]["close"] == 100.0 + + @pytest.mark.asyncio + async def test_multiple_event_types_forwarding(self): + """Test that different event types are all forwarded correctly.""" + suite_bus = EventBus() + instrument_bus = EventBus() + + # Set up forwarding + await instrument_bus.forward_to(suite_bus) + + # Track different event types + events_by_type = { + EventType.NEW_BAR: [], + EventType.QUOTE_UPDATE: [], + EventType.TRADE_TICK: [], + EventType.CONNECTED: [], + } + + # Register handlers for each type + for event_type in events_by_type: + async def make_handler(et): + async def handler(event): + events_by_type[et].append(event) + return handler + + await suite_bus.on(event_type, await make_handler(event_type)) + + # Emit various events + await instrument_bus.emit(EventType.NEW_BAR, {"type": "bar"}) + await instrument_bus.emit(EventType.QUOTE_UPDATE, {"type": "quote"}) + await instrument_bus.emit(EventType.TRADE_TICK, {"type": "trade"}) + await instrument_bus.emit(EventType.CONNECTED, {"type": "connected"}) + + # Allow propagation + await asyncio.sleep(0.1) + + # Verify all event types were forwarded + assert len(events_by_type[EventType.NEW_BAR]) == 1 + assert len(events_by_type[EventType.QUOTE_UPDATE]) == 1 + assert len(events_by_type[EventType.TRADE_TICK]) == 1 + assert len(events_by_type[EventType.CONNECTED]) == 1 + + @pytest.mark.asyncio + async def test_event_handler_removal(self): + """Test that event handlers can be properly removed.""" + event_bus = EventBus() + + events_received = [] + + async def handler(event): + events_received.append(event) + + # Register handler + await event_bus.on(EventType.NEW_BAR, handler) + + # Emit event - should be received + await event_bus.emit(EventType.NEW_BAR, {"test": 1}) + await asyncio.sleep(0.05) + assert len(events_received) == 1 + + # Remove handler + await event_bus.off(EventType.NEW_BAR, handler) + + # Emit again - should not be received + await event_bus.emit(EventType.NEW_BAR, {"test": 2}) + await asyncio.sleep(0.05) + assert len(events_received) == 1 # Still only 1 + + @pytest.mark.asyncio + async def test_once_handler(self): + """Test that once handlers only fire once.""" + event_bus = EventBus() + + events_received = [] + + async def handler(event): + events_received.append(event) + + # Register once handler + await event_bus.once(EventType.NEW_BAR, handler) + + # Emit multiple events + await event_bus.emit(EventType.NEW_BAR, {"test": 1}) + await event_bus.emit(EventType.NEW_BAR, {"test": 2}) + await event_bus.emit(EventType.NEW_BAR, {"test": 3}) + + await asyncio.sleep(0.1) + + # Should only receive first event + assert len(events_received) == 1 + assert events_received[0].data["test"] == 1 + + @pytest.mark.asyncio + async def test_wildcard_event_forwarding(self): + """Test that wildcard handlers forward all events.""" + suite_bus = EventBus() + instrument_bus = EventBus() + + # Set up forwarding (uses wildcard internally) + await instrument_bus.forward_to(suite_bus) + + # Track all events at suite level + all_events = [] + + async def wildcard_handler(event): + all_events.append(event) + + await suite_bus.on_any(wildcard_handler) + + # Emit various event types + await instrument_bus.emit(EventType.NEW_BAR, {"type": "bar"}) + await instrument_bus.emit(EventType.QUOTE_UPDATE, {"type": "quote"}) + await instrument_bus.emit("custom_event", {"type": "custom"}) + + await asyncio.sleep(0.1) + + # Should receive all events + assert len(all_events) == 3 + assert any(e.data["type"] == "bar" for e in all_events) + assert any(e.data["type"] == "quote" for e in all_events) + assert any(e.data["type"] == "custom" for e in all_events) + + @pytest.mark.asyncio + async def test_event_data_structure_for_new_bar(self): + """Test the expected data structure for NEW_BAR events.""" + event_bus = EventBus() + + received_event = None + + async def handler(event): + nonlocal received_event + received_event = event + + await event_bus.on(EventType.NEW_BAR, handler) + + # Emit with expected structure + bar_data = { + "timeframe": "1min", + "data": { + "open": 100.0, + "high": 101.0, + "low": 99.0, + "close": 100.5, + "volume": 1000, + "timestamp": "2024-01-01T00:00:00Z" + } + } + + await event_bus.emit(EventType.NEW_BAR, bar_data) + await asyncio.sleep(0.05) + + # Verify structure + assert received_event is not None + assert "timeframe" in received_event.data + assert "data" in received_event.data + + inner_data = received_event.data["data"] + assert "open" in inner_data + assert "high" in inner_data + assert "low" in inner_data + assert "close" in inner_data + assert "volume" in inner_data + + @pytest.mark.asyncio + async def test_event_handling_order(self): + """Test that events are handled in order.""" + event_bus = EventBus() + + processing_order = [] + + async def handler(event): + processing_order.append(f"event_{event.data['id']}") + + await event_bus.on(EventType.NEW_BAR, handler) + + # Emit multiple events + for i in range(3): + await event_bus.emit(EventType.NEW_BAR, {"id": i}) + + # Wait for processing + await asyncio.sleep(0.1) + + # Events should be processed in order + assert len(processing_order) == 3 + assert processing_order[0] == "event_0" + assert processing_order[1] == "event_1" + assert processing_order[2] == "event_2" + + +class TestEventSystemRegression: + """Regression tests to prevent breaking the event system again.""" + + @pytest.mark.asyncio + async def test_instrument_context_methods_delegate_to_event_bus(self): + """Ensure InstrumentContext methods properly delegate to event_bus.""" + from project_x_py.trading_suite import InstrumentContext + + # Create mock event bus + mock_event_bus = AsyncMock(spec=EventBus) + + # Create context + context = InstrumentContext( + symbol="MNQ", + instrument_info=MagicMock(), + data=MagicMock(), + orders=MagicMock(), + positions=MagicMock(), + event_bus=mock_event_bus + ) + + # Test delegation + handler = AsyncMock() + + await context.on(EventType.NEW_BAR, handler) + mock_event_bus.on.assert_called_once_with(EventType.NEW_BAR, handler) + + await context.once(EventType.QUOTE_UPDATE, handler) + mock_event_bus.once.assert_called_once_with(EventType.QUOTE_UPDATE, handler) + + await context.off(EventType.NEW_BAR, handler) + mock_event_bus.off.assert_called_once_with(EventType.NEW_BAR, handler) + + await context.wait_for(EventType.CONNECTED, timeout=5.0) + mock_event_bus.wait_for.assert_called_once_with(EventType.CONNECTED, 5.0) + + @pytest.mark.asyncio + async def test_forward_to_method_exists_and_works(self): + """Ensure EventBus.forward_to method exists and functions correctly.""" + source_bus = EventBus() + target_bus = EventBus() + + # Verify method exists + assert hasattr(source_bus, "forward_to") + assert callable(source_bus.forward_to) + + # Set up forwarding + await source_bus.forward_to(target_bus) + + # Verify forwarding works + events = [] + + async def handler(event): + events.append(event) + + await target_bus.on(EventType.NEW_BAR, handler) + await source_bus.emit(EventType.NEW_BAR, {"test": "data"}) + + await asyncio.sleep(0.05) + + assert len(events) == 1 + assert events[0].data["test"] == "data" + + @pytest.mark.asyncio + async def test_trading_suite_setup_event_forwarding_called(self): + """Ensure _setup_event_forwarding is called during TradingSuite creation.""" + with patch("project_x_py.trading_suite.TradingSuite._setup_event_forwarding") as mock_setup: + mock_setup.return_value = asyncio.Future() + mock_setup.return_value.set_result(None) + + with patch("project_x_py.trading_suite.ProjectX") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.from_env.return_value.__aenter__.return_value = mock_client + mock_client.authenticate = AsyncMock() + mock_client.get_instrument = AsyncMock( + return_value=MagicMock(symbol="MNQ", name="MNQ", exchange="CME", min_tick=0.25) + ) + mock_client.get_bars = AsyncMock(return_value=MagicMock(is_empty=lambda: True)) + mock_client.search_positions = AsyncMock(return_value=[]) + mock_client.search_orders = AsyncMock(return_value=[]) + + with patch("project_x_py.realtime.ProjectXRealtimeClient") as mock_rt: + mock_rt.return_value = AsyncMock() + mock_rt.return_value.connect = AsyncMock(return_value=True) + mock_rt.return_value.is_connected = True + + try: + suite = await TradingSuite.create("MNQ", timeframes=["1min"], auto_connect=False) + # Verify _setup_event_forwarding was called + mock_setup.assert_called() + except Exception: + # Even if creation fails for other reasons, we just want to verify the method was called + if mock_setup.called: + pass # Test passes + else: + raise diff --git a/tests/trading_suite/test_complete_coverage.py b/tests/trading_suite/test_complete_coverage.py index c736174..6f2434a 100644 --- a/tests/trading_suite/test_complete_coverage.py +++ b/tests/trading_suite/test_complete_coverage.py @@ -296,6 +296,7 @@ async def test_cleanup_contexts(self): data=AsyncMock(spec=RealtimeDataManager), orders=AsyncMock(spec=OrderManager), positions=AsyncMock(spec=PositionManager), + event_bus=AsyncMock(spec=EventBus), orderbook=AsyncMock(spec=OrderBook), risk_manager=AsyncMock(spec=RiskManager), ) @@ -306,6 +307,7 @@ async def test_cleanup_contexts(self): data=AsyncMock(spec=RealtimeDataManager), orders=AsyncMock(spec=OrderManager), positions=AsyncMock(spec=PositionManager), + event_bus=AsyncMock(spec=EventBus), orderbook=None, # No orderbook risk_manager=None, # No risk manager ) @@ -337,6 +339,7 @@ async def test_cleanup_contexts_with_errors(self): data=AsyncMock(spec=RealtimeDataManager), orders=AsyncMock(spec=OrderManager), positions=AsyncMock(spec=PositionManager), + event_bus=AsyncMock(spec=EventBus), orderbook=AsyncMock(spec=OrderBook), risk_manager=None, ) @@ -491,6 +494,7 @@ async def _create_suite_with_contexts(self, symbols): data=Mock(spec=RealtimeDataManager), orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=None, risk_manager=None, ) @@ -610,6 +614,7 @@ async def _create_suite_with_session_support(self, symbols): data=data_manager, orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=None, risk_manager=None, ) @@ -711,6 +716,7 @@ async def _create_suite_with_single_context(self): data=Mock(spec=RealtimeDataManager), orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=None, risk_manager=None, ) @@ -733,6 +739,7 @@ async def _create_suite_with_multiple_contexts(self): data=Mock(spec=RealtimeDataManager), orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=None, risk_manager=None, ) @@ -791,6 +798,7 @@ async def test_disconnect_multi_instrument(self): data=data, orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=orderbook, risk_manager=None, ) @@ -897,6 +905,7 @@ async def test_data_property_multi_instrument_mode(self): data=Mock(spec=RealtimeDataManager), orders=Mock(spec=OrderManager), positions=Mock(spec=PositionManager), + event_bus=Mock(spec=EventBus), orderbook=None, risk_manager=None, ) diff --git a/tests/trading_suite/test_multi_instrument.py b/tests/trading_suite/test_multi_instrument.py index b888f41..b4dd695 100644 --- a/tests/trading_suite/test_multi_instrument.py +++ b/tests/trading_suite/test_multi_instrument.py @@ -34,6 +34,7 @@ async def test_instrument_context_creation(): mock_data_manager = MagicMock() mock_order_manager = MagicMock() mock_position_manager = MagicMock() + mock_event_bus = MagicMock() mock_orderbook = MagicMock() mock_risk_manager = MagicMock() @@ -44,6 +45,7 @@ async def test_instrument_context_creation(): data=mock_data_manager, orders=mock_order_manager, positions=mock_position_manager, + event_bus=mock_event_bus, orderbook=mock_orderbook, risk_manager=mock_risk_manager, ) diff --git a/uv.lock b/uv.lock index 2db7484..c580e72 100644 --- a/uv.lock +++ b/uv.lock @@ -2360,7 +2360,7 @@ wheels = [ [[package]] name = "project-x-py" -version = "3.5.5" +version = "3.5.6" source = { editable = "." } dependencies = [ { name = "cachetools" },