diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..e04a80b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,39 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- Source: `src/project_x_py/` (core SDK: client, realtime, orderbook, indicators, utils, etc.). +- Tests: `tests/` (pytest; async tests supported). +- Examples: `examples/` (runnable end-to-end samples). +- Docs: `docs/` (Sphinx); helper script `scripts/build-docs.py`. +- Scripts: `scripts/` (build, docs, versioning). Build artifacts in `dist/` and coverage in `htmlcov/`. + +## Build, Test, and Development Commands +- Install dev env: `uv sync` (or `pip install -e ".[dev]"`). +- Run tests + coverage: `uv run pytest` (HTML at `htmlcov/index.html`). +- Lint: `uv run ruff check .` Format: `uv run ruff format .` Types: `uv run mypy src`. +- Docs: `uv run python scripts/build-docs.py --clean --open`. +- CLI helpers: `uv run projectx-check` and `uv run projectx-config`. +- Run an example: `uv run python examples/01_basic_client_connection.py`. + +## Coding Style & Naming Conventions +- Python 3.12+, 4-space indents, max line length 88. +- Format with Ruff formatter (Black-compatible); import order via Ruff/isort. +- Naming follows PEP 8; uppercase class names allowed in `indicators/` (see Ruff per-file ignores). +- Keep functions small, typed, and documented where behavior is non-obvious. + +## Testing Guidelines +- Framework: pytest (+ pytest-asyncio). Place tests under `tests/` as `test_*.py`. +- Marks: `unit`, `integration`, `slow`, `realtime` (see `pyproject.toml`). +- Aim for meaningful coverage of public APIs; coverage reports are produced automatically. +- Prefer async-safe patterns; use fixtures and markers to isolate realtime or networked tests. + +## Commit & Pull Request Guidelines +- Use Conventional Commits: `feat:`, `fix:`, `perf:`, `docs:`, `chore:`, etc. Add scope when helpful: `fix(orderbook): ...`. +- Keep subject ≤ 72 chars; body explains what/why and migration notes if breaking. +- Before PR: run `uv run ruff format . && uv run ruff check . && uv run mypy src && uv run pytest`. +- PRs include: clear description, linked issues, test updates, docs updates (if user-facing), and screenshots/logs when relevant. + +## Security & Configuration Tips +- Auth via env vars `PROJECT_X_API_KEY`, `PROJECT_X_USERNAME`, or config at `~/.config/projectx/config.json`. +- Avoid committing secrets; prefer `.env` locally and CI secrets in GitHub. +- When adding realtime features, guard network calls in tests with markers. diff --git a/CHANGELOG.md b/CHANGELOG.md index 51ba645..2baf0d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Old implementations are removed when improved - Clean, modern code architecture is prioritized +## [3.1.1] - 2025-08-10 + +### Fixed +- **🐛 Test Suite Compatibility**: Fixed all failing tests for optimized cache implementation + - Updated test references from old cache variables (`_instrument_cache`) to new optimized ones (`_opt_instrument_cache`) + - Fixed datetime serialization/deserialization in cached DataFrames to properly preserve timezone information + - Resolved BatchedWebSocketHandler flush mechanism with event-based signaling for immediate message processing + - Fixed race condition in BatchedWebSocketHandler task creation + - Corrected SignalR mock methods in connection management tests (changed from AsyncMock to MagicMock for synchronous methods) + +### Improved +- **✨ Cache Serialization**: Enhanced datetime handling in msgpack cache + - Proper timezone preservation for datetime columns in Polars DataFrames + - More robust deserialization with fallback handling + - Better datetime string format compatibility + ## [3.1.0] - 2025-08-09 ### Added diff --git a/README.md b/README.md index 3d767ce..f94ae4b 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,15 @@ A **high-performance async Python SDK** for the [ProjectX Trading Platform](http This Python SDK acts as a bridge between your trading strategies and the ProjectX platform, handling all the complex API interactions, data processing, and real-time connectivity. -## 🚀 v3.1.0 - High-Performance Production Suite +## 🚀 v3.1.1 - High-Performance Production Suite -**Latest Update (v3.1.0)**: Major performance optimizations delivering 2-5x improvements across the board with automatic memory management and enterprise-grade caching. +**Latest Update (v3.1.1)**: Bug fixes and improvements for test suite compatibility with optimized cache implementation, enhanced datetime serialization, and WebSocket handler improvements. + +### What's New in v3.1.1 +- **Fixed**: Test suite compatibility with optimized cache implementation +- **Fixed**: Datetime serialization/deserialization in cached DataFrames +- **Fixed**: BatchedWebSocketHandler flush and race condition issues +- **Fixed**: SignalR mock methods in connection management tests ### What's New in v3.1.0 diff --git a/docs/conf.py b/docs/conf.py index 30abaa4..b87905b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,8 +23,8 @@ project = "project-x-py" copyright = "2025, Jeff West" author = "Jeff West" -release = "3.1.0" -version = "3.1.0" +release = "3.1.1" +version = "3.1.1" # -- General configuration --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 52c5bb6..100648d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "project-x-py" -version = "3.1.0" +version = "3.1.1" description = "High-performance Python SDK for futures trading with real-time WebSocket data, technical indicators, order management, and market depth analysis" readme = "README.md" license = { text = "MIT" } diff --git a/src/project_x_py/__init__.py b/src/project_x_py/__init__.py index f4e456e..65c5606 100644 --- a/src/project_x_py/__init__.py +++ b/src/project_x_py/__init__.py @@ -95,7 +95,7 @@ from project_x_py.client.base import ProjectXBase -__version__ = "3.1.0" +__version__ = "3.1.1" __author__ = "TexasCoding" # Core client classes - renamed from Async* to standard names diff --git a/src/project_x_py/client/cache.py b/src/project_x_py/client/cache.py index b4b129f..870adf1 100644 --- a/src/project_x_py/client/cache.py +++ b/src/project_x_py/client/cache.py @@ -9,6 +9,7 @@ import gc import logging +import re import time from typing import TYPE_CHECKING, Any @@ -75,9 +76,20 @@ def _serialize_dataframe(self, df: pl.DataFrame) -> bytes: return b"" # Convert to dictionary format for msgpack + columns_data = {} + for col in df.columns: + col_data = df[col] + # Convert datetime columns to ISO strings for msgpack serialization + if col_data.dtype in [pl.Datetime, pl.Date]: + columns_data[col] = col_data.dt.to_string( + "%Y-%m-%d %H:%M:%S%.f" + ).to_list() + else: + columns_data[col] = col_data.to_list() + data = { "schema": {name: str(dtype) for name, dtype in df.schema.items()}, - "columns": {col: df[col].to_list() for col in df.columns}, + "columns": columns_data, "shape": df.shape, } @@ -132,8 +144,34 @@ def _deserialize_dataframe(self, data: bytes) -> pl.DataFrame | None: if not unpacked or "columns" not in unpacked: return None - # Reconstruct DataFrame - return pl.DataFrame(unpacked["columns"]) + # Reconstruct DataFrame with proper schema + df = pl.DataFrame(unpacked["columns"]) + + # Restore datetime columns based on stored schema + if "schema" in unpacked: + for col_name, dtype_str in unpacked["schema"].items(): + if "datetime" in dtype_str.lower() and col_name in df.columns: + # Parse timezone from dtype string (e.g., "Datetime(time_unit='us', time_zone='UTC')") + time_zone = None + if "time_zone=" in dtype_str: + # Extract timezone + tz_match = re.search(r"time_zone='([^']+)'", dtype_str) + if tz_match: + time_zone = tz_match.group(1) + + # Convert string column to datetime + if df[col_name].dtype == pl.Utf8: + df = df.with_columns( + pl.col(col_name) + .str.strptime( + pl.Datetime("us", time_zone), + "%Y-%m-%d %H:%M:%S%.f", + strict=False, + ) + .alias(col_name) + ) + + return df except Exception as e: logger.debug(f"Failed to deserialize DataFrame: {e}") return None diff --git a/src/project_x_py/indicators/__init__.py b/src/project_x_py/indicators/__init__.py index 49c9633..376417b 100644 --- a/src/project_x_py/indicators/__init__.py +++ b/src/project_x_py/indicators/__init__.py @@ -202,7 +202,7 @@ ) # Version info -__version__ = "3.1.0" +__version__ = "3.1.1" __author__ = "TexasCoding" diff --git a/src/project_x_py/realtime/batched_handler.py b/src/project_x_py/realtime/batched_handler.py index 9bcd973..5ae417e 100644 --- a/src/project_x_py/realtime/batched_handler.py +++ b/src/project_x_py/realtime/batched_handler.py @@ -65,6 +65,9 @@ def __init__( # Lock for thread safety self._lock = asyncio.Lock() + # Event to signal immediate flush + self._flush_event = asyncio.Event() + async def handle_message(self, message: dict[str, Any]) -> None: """ Add a message to the batch queue for processing. @@ -76,7 +79,9 @@ async def handle_message(self, message: dict[str, Any]) -> None: self.message_queue.append(message) # Start batch processing if not already running - if not self.processing: + if not self.processing and ( + not self._processing_task or self._processing_task.done() + ): self._processing_task = asyncio.create_task(self._process_batch()) async def _process_batch(self) -> None: @@ -90,17 +95,39 @@ async def _process_batch(self) -> None: batch: list[dict[str, Any]] = [] deadline = time.time() + self.batch_timeout - # Collect messages until batch is full or timeout - while time.time() < deadline and len(batch) < self.batch_size: + # Collect messages until batch is full or timeout or flush is requested + while ( + time.time() < deadline + and len(batch) < self.batch_size + and not self._flush_event.is_set() + ): if self.message_queue: # Get all available messages up to batch size while self.message_queue and len(batch) < self.batch_size: batch.append(self.message_queue.popleft()) else: - # Wait a bit for more messages + # Wait a bit for more messages or flush event remaining = deadline - time.time() if remaining > 0: - await asyncio.sleep(min(0.001, remaining)) + try: + # Wait for either timeout or flush event + await asyncio.wait_for( + self._flush_event.wait(), + timeout=min( + 0.01, remaining + ), # Increased from 0.001 to 0.01 + ) + # Flush was triggered, break the loop + break + except TimeoutError: + # Normal timeout, continue + pass + + # If flush was triggered, get any remaining messages + if self._flush_event.is_set(): + while self.message_queue and len(batch) < 10000: # Safety limit + batch.append(self.message_queue.popleft()) + self._flush_event.clear() # Process the batch if we have messages if batch: @@ -148,17 +175,16 @@ async def _process_batch(self) -> None: async def flush(self) -> None: """Force processing of all queued messages immediately.""" - # Wait for any current processing to complete first + # Signal the processing task to flush immediately + self._flush_event.set() + + # Wait for the current processing task to complete if it exists if self._processing_task and not self._processing_task.done(): - with contextlib.suppress(TimeoutError, asyncio.CancelledError): + with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError): await asyncio.wait_for(self._processing_task, timeout=1.0) - # Give the processing task a chance to actually process - await asyncio.sleep(0) # Yield to let other tasks run - - # Now process any remaining messages - while self.message_queue: - # Process all remaining messages + # Process any remaining messages that weren't picked up + if self.message_queue: batch = list(self.message_queue) self.message_queue.clear() @@ -170,6 +196,10 @@ async def flush(self) -> None: except Exception as e: logger.error(f"Error flushing batch of {len(batch)} messages: {e}") + # Clear the flush event for next time + self._flush_event.clear() + self.processing = False + def get_stats(self) -> dict[str, Any]: """ Get performance statistics for the batch handler. @@ -202,6 +232,9 @@ def get_stats(self) -> dict[str, Any]: async def stop(self) -> None: """Stop the batch handler and process remaining messages.""" + # Signal flush to trigger immediate processing + self._flush_event.set() + # Wait for current processing to complete if self._processing_task and not self._processing_task.done(): try: @@ -211,7 +244,7 @@ async def stop(self) -> None: with contextlib.suppress(asyncio.CancelledError): await self._processing_task - # Flush remaining messages + # Process any remaining messages that weren't handled await self.flush() logger.info( diff --git a/tests/client/test_client_integration.py b/tests/client/test_client_integration.py index 172116f..80a054c 100644 --- a/tests/client/test_client_integration.py +++ b/tests/client/test_client_integration.py @@ -75,7 +75,7 @@ async def test_auth_market_data_workflow( # Step 2: Get instrument data instrument = await client.get_instrument("MGC") assert instrument is not None - assert "MGC" in client._instrument_cache + assert "MGC" in client._opt_instrument_cache # Step 3: Get market data bars = await client.get_bars("MGC", days=5, interval=5) @@ -85,7 +85,7 @@ async def test_auth_market_data_workflow( # Step 4: Verify cache is populated cache_key = "MGC_5_5_2_True" - assert cache_key in client._market_data_cache + assert cache_key in client._opt_market_data_cache @pytest.mark.asyncio async def test_trading_workflow( diff --git a/tests/client/test_market_data.py b/tests/client/test_market_data.py index 3148650..6eea1bc 100644 --- a/tests/client/test_market_data.py +++ b/tests/client/test_market_data.py @@ -30,10 +30,10 @@ async def test_get_instrument( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -47,7 +47,7 @@ async def test_get_instrument( assert instrument.name == "Micro Gold Futures" # Should have cached the result - assert "MGC" in client._instrument_cache + assert "MGC" in client._opt_instrument_cache @pytest.mark.asyncio async def test_get_instrument_from_cache( @@ -64,10 +64,10 @@ async def test_get_instrument_from_cache( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -106,10 +106,10 @@ async def test_get_instrument_not_found( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -137,10 +137,10 @@ async def test_search_instruments( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -171,10 +171,10 @@ async def test_search_instruments_empty_result( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -192,10 +192,10 @@ async def test_select_best_contract(self, mock_httpx_client): async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -249,10 +249,10 @@ async def test_get_bars( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -275,7 +275,7 @@ async def test_get_bars( # Should cache the result cache_key = "MGC_5_5_2_True" - assert cache_key in client._market_data_cache + assert cache_key in client._opt_market_data_cache @pytest.mark.asyncio async def test_get_bars_from_cache(self, mock_httpx_client, mock_auth_response): @@ -290,10 +290,10 @@ async def test_get_bars_from_cache(self, mock_httpx_client, mock_auth_response): async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -353,10 +353,10 @@ async def test_get_bars_empty_response( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -396,10 +396,10 @@ async def test_get_bars_error_response( async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -472,10 +472,10 @@ def request_matcher(**kwargs): async with ProjectX("testuser", "test-api-key") as client: # Initialize required attributes client.api_call_count = 0 - client._instrument_cache = {} - client._instrument_cache_time = {} - client._market_data_cache = {} - client._market_data_cache_time = {} + client._opt_instrument_cache = {} + client._opt_instrument_cache_time = {} + client._opt_market_data_cache = {} + client._opt_market_data_cache_time = {} client.cache_ttl = 300 client.last_cache_cleanup = time.time() client.cache_hit_count = 0 @@ -491,5 +491,5 @@ def request_matcher(**kwargs): assert not hourly_bars.is_empty() # Different cache keys should be used - assert "MGC_30_1_4_True" in client._market_data_cache - assert "MGC_7_1_3_True" in client._market_data_cache + assert "MGC_30_1_4_True" in client._opt_market_data_cache + assert "MGC_7_1_3_True" in client._opt_market_data_cache diff --git a/tests/realtime/test_batched_handler.py b/tests/realtime/test_batched_handler.py index 5552e1b..f6f9ca1 100644 --- a/tests/realtime/test_batched_handler.py +++ b/tests/realtime/test_batched_handler.py @@ -113,11 +113,13 @@ async def process_batch(batch): for i in range(7): await handler.handle_message({"id": i}) - # Flush immediately + # Flush immediately - should interrupt the processing task await handler.flush() # Should have processed all messages - assert len(processed_batches) == 1 + assert len(processed_batches) == 1, ( + f"Expected 1 batch, got {len(processed_batches)} batches: {processed_batches}" + ) assert len(processed_batches[0]) == 7 assert handler.messages_processed == 7 @@ -190,8 +192,8 @@ async def process_batch(batch): for i in range(5): await handler.handle_message({"id": i}) - # Give a tiny bit of time for the task to start (but not complete) - await asyncio.sleep(0.01) + # Give the processing task a moment to start + await asyncio.sleep(0) # Yield control to allow task to start # Stop should flush remaining messages await handler.stop() diff --git a/tests/realtime/test_connection_management.py b/tests/realtime/test_connection_management.py index ee2c28d..4cf8e30 100644 --- a/tests/realtime/test_connection_management.py +++ b/tests/realtime/test_connection_management.py @@ -1,6 +1,6 @@ """Tests for realtime connection management.""" -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -68,7 +68,8 @@ async def test_connect_success(self, connection_mixin): "project_x_py.realtime.connection_management.HubConnectionBuilder" ) as mock_builder: mock_connection = MagicMock() - mock_connection.start = AsyncMock(return_value=True) + # Use regular Mock for synchronous start method + mock_connection.start = MagicMock(return_value=True) mock_builder.return_value.with_url.return_value.configure_logging.return_value.with_automatic_reconnect.return_value.build.return_value = mock_connection result = await mixin.connect() @@ -86,7 +87,8 @@ async def test_connect_failure(self, connection_mixin): "project_x_py.realtime.connection_management.HubConnectionBuilder" ) as mock_builder: mock_connection = MagicMock() - mock_connection.start = AsyncMock( + # Use regular Mock for synchronous start method + mock_connection.start = MagicMock( side_effect=Exception("Connection failed") ) mock_builder.return_value.with_url.return_value.configure_logging.return_value.with_automatic_reconnect.return_value.build.return_value = mock_connection @@ -100,9 +102,11 @@ async def test_disconnect(self, connection_mixin): """Test graceful disconnection.""" mixin = connection_mixin mock_user_connection = MagicMock() - mock_user_connection.stop = AsyncMock(return_value=None) + # Use regular Mock for synchronous stop method + mock_user_connection.stop = MagicMock(return_value=None) mock_market_connection = MagicMock() - mock_market_connection.stop = AsyncMock(return_value=None) + # Use regular Mock for synchronous stop method + mock_market_connection.stop = MagicMock(return_value=None) mixin.user_connection = mock_user_connection mixin.market_connection = mock_market_connection @@ -124,7 +128,8 @@ async def test_reconnect_on_connection_lost(self, connection_mixin): "project_x_py.realtime.connection_management.HubConnectionBuilder" ) as mock_builder: mock_connection = MagicMock() - mock_connection.start = AsyncMock(return_value=True) + # Use regular Mock for synchronous start method + mock_connection.start = MagicMock(return_value=True) mock_builder.return_value.with_url.return_value.configure_logging.return_value.with_automatic_reconnect.return_value.build.return_value = mock_connection # Connect initially @@ -134,7 +139,7 @@ async def test_reconnect_on_connection_lost(self, connection_mixin): await mixin.disconnect() # Reconnect - result = await mixin.connect() + await mixin.connect() # Should be able to reconnect assert mock_connection.start.called diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..b31b533 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,258 @@ +"""Unit tests for data models in project_x_py.models.""" + +from __future__ import annotations + +import pytest + +from project_x_py.models import ( + Account, + BracketOrderResponse, + Instrument, + MarketDataEvent, + Order, + OrderPlaceResponse, + OrderUpdateEvent, + Position, + PositionUpdateEvent, + ProjectXConfig, + Trade, +) + + +class TestInstrumentAndAccount: + def test_instrument_creation_defaults(self): + inst = Instrument( + id="CON.F.US.MNQ.H25", + name="MNQH25", + description="Micro Nasdaq March 2025", + tickSize=0.25, + tickValue=0.5, + activeContract=True, + ) + assert inst.id == "CON.F.US.MNQ.H25" + assert inst.symbolId is None + + def test_account_creation(self): + acct = Account( + id=101, + name="Sim-101", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=True, + ) + assert acct.name == "Sim-101" + assert acct.balance == pytest.approx(10000.0) + + +class TestOrderModel: + def make_order(self, **overrides) -> Order: + base = dict( + id=1, + accountId=10, + contractId="CON.F.US.MNQ.H25", + creationTimestamp="2024-01-01T00:00:00Z", + updateTimestamp="2024-01-01T00:00:10Z", + status=1, # OPEN + type=1, # LIMIT + side=0, # BUY + size=5, + ) + base.update(overrides) + return Order(**base) + + def test_order_state_properties(self): + o = self.make_order(status=1) # OPEN + assert o.is_open and o.is_working and not o.is_terminal + + o = self.make_order(status=6) # PENDING + assert o.is_working and not o.is_open + + o = self.make_order(status=2) # FILLED + assert o.is_filled and o.is_terminal + + o = self.make_order(status=3) # CANCELLED + assert o.is_cancelled and o.is_terminal + + o = self.make_order(status=5) # REJECTED + assert o.is_rejected and o.is_terminal + + def test_order_side_type_status_strings(self): + o = self.make_order(side=0, type=2, status=6) # BUY, MARKET, PENDING + assert o.side_str == "BUY" + assert o.type_str == "MARKET" + assert o.status_str == "PENDING" + + # Unknown values map to UNKNOWN + o = self.make_order(side=1, type=99, status=99) + assert o.side_str == "SELL" + assert o.type_str == "UNKNOWN" + assert o.status_str == "UNKNOWN" + + def test_order_fill_progress_and_remaining(self): + o = self.make_order(size=10, fillVolume=None) + assert o.filled_percent == 0.0 + assert o.remaining_size == 10 + + o = self.make_order(size=10, fillVolume=3) + assert o.filled_percent == pytest.approx(30.0) + assert o.remaining_size == 7 + + o = self.make_order(size=0, fillVolume=0) + assert o.filled_percent == 0.0 + + def test_order_symbol_extraction(self): + o = self.make_order(contractId="CON.F.US.MNQ.H25") + assert o.symbol == "MNQ" + + # Fallback: no dots + o = self.make_order(contractId="MNQH25") + assert o.symbol == "MNQH25" + + +class TestPositionModel: + def make_position(self, **overrides) -> Position: + base = dict( + id=42, + accountId=10, + contractId="CON.F.US.MGC.M25", + creationTimestamp="2024-01-01T00:00:00Z", + type=1, # LONG + size=2, + averagePrice=2050.0, + ) + base.update(overrides) + return Position(**base) + + def test_basic_properties_and_indexing(self): + p = self.make_position() + assert p.is_long and not p.is_short + assert p.direction == "LONG" + assert p["averagePrice"] == pytest.approx(2050.0) + assert p.symbol == "MGC" + + def test_short_position_helpers(self): + p = self.make_position(type=2, size=3) + assert p.is_short + assert p.direction == "SHORT" + assert p.signed_size == -3 + + def test_direction_undefined(self): + p = self.make_position(type=0) + assert p.direction == "UNDEFINED" + assert p.signed_size == 2 # stays positive + + def test_total_cost_and_unrealized_pnl(self): + p = self.make_position(size=4, averagePrice=100.0) + assert p.total_cost == pytest.approx(400.0) + + # Long: price up => positive PnL + assert p.unrealized_pnl(101.0) == pytest.approx(4.0) + # Custom tick value multiplier + assert p.unrealized_pnl(101.0, tick_value=5.0) == pytest.approx(20.0) + + # Short: price down => positive PnL + p_short = self.make_position(type=2, size=4, averagePrice=100.0) + assert p_short.unrealized_pnl(99.0) == pytest.approx(4.0) + + # Undefined => zero + p_undef = self.make_position(type=0) + assert p_undef.unrealized_pnl(999.0) == 0.0 + + def test_symbol_fallback(self): + p = self.make_position(contractId="SYNTH-ABC") + assert p.symbol == "SYNTH-ABC" + + +class TestTradeModel: + def test_trade_slots_and_attributes(self): + t = Trade( + id=7, + accountId=10, + contractId="CON.F.US.MNQ.H25", + creationTimestamp="2024-01-01T00:00:00Z", + price=5000.0, + profitAndLoss=None, + fees=2.5, + side=0, + size=1, + voided=False, + orderId=123, + ) + + # Access attributes + assert t.price == pytest.approx(5000.0) + assert t.profitAndLoss is None # half-turn trade allowed + + # __slots__ should prevent setting unknown attributes + with pytest.raises(AttributeError): + t.extra = "not-allowed" # type: ignore[attr-defined] + + +class TestBracketAndResponses: + def test_order_place_response(self): + r = OrderPlaceResponse( + orderId=555, success=True, errorCode=0, errorMessage=None + ) + assert r.success is True + assert r.orderId == 555 + + def test_bracket_order_response(self): + entry = OrderPlaceResponse( + orderId=1, success=True, errorCode=0, errorMessage=None + ) + # stop/target responses could be None if not created + br = BracketOrderResponse( + success=True, + entry_order_id=1, + stop_order_id=2, + target_order_id=3, + entry_price=100.0, + stop_loss_price=95.0, + take_profit_price=110.0, + entry_response=entry, + stop_response=None, + target_response=None, + error_message=None, + ) + assert br.success is True + assert br.entry_price == pytest.approx(100.0) + assert br.entry_response and br.entry_response.orderId == 1 + + +class TestConfigAndEvents: + def test_projectxconfig_defaults_and_overrides(self): + cfg = ProjectXConfig() + assert cfg.api_url == "https://api.topstepx.com/api" + assert cfg.user_hub_url.endswith("/hubs/user") + assert cfg.market_hub_url.endswith("/hubs/market") + assert cfg.timezone == "America/Chicago" + assert cfg.timeout_seconds == 30 + assert cfg.retry_attempts == 3 + + # Override a few and ensure they stick + cfg = ProjectXConfig(api_url="https://custom/api", timezone="UTC") + assert cfg.api_url == "https://custom/api" + assert cfg.timezone == "UTC" + + def test_event_dataclasses(self): + oue = OrderUpdateEvent(orderId=1, status=2, fillVolume=1, updateTimestamp="t") + pue = PositionUpdateEvent( + positionId=2, + contractId="CON.F.US.MNQ.H25", + size=1, + averagePrice=10.0, + updateTimestamp="t", + ) + mde = MarketDataEvent( + contractId="CON.F.US.MNQ.H25", + lastPrice=1.0, + bid=None, + ask=1.25, + volume=10, + timestamp="t", + ) + + assert oue.status == 2 + assert pue.size == 1 and pue.contractId.endswith("MNQ.H25") + assert mde.ask == pytest.approx(1.25)