Skip to content

Commit 048a397

Browse files
authored
feat(tests): add position manager tests with fixtures for analytics, core, risk, and tracking
Co-authored-by: Genie <[email protected]> Add comprehensive testing suite for PositionManager
2 parents 1f6fd40 + 488fc5e commit 048a397

File tree

6 files changed

+183
-0
lines changed

6 files changed

+183
-0
lines changed

tests/position_manager/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Mark tests/position_manager as a package for pytest discovery.

tests/position_manager/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
from unittest.mock import AsyncMock, patch
3+
4+
from project_x_py.position_manager.core import PositionManager
5+
from project_x_py.models import Position
6+
7+
@pytest.fixture
8+
async def position_manager(initialized_client, mock_positions_data):
9+
"""Fixture for PositionManager with mocked ProjectX client and open positions."""
10+
# Convert mock_positions_data dicts to Position objects
11+
positions = [Position(**data) for data in mock_positions_data]
12+
13+
# Patch search_open_positions to AsyncMock returning Position objects
14+
initialized_client.search_open_positions = AsyncMock(return_value=positions)
15+
# Optionally patch other APIs as needed for isolation
16+
17+
pm = PositionManager(initialized_client)
18+
yield pm
19+
20+
@pytest.fixture
21+
def populate_prices():
22+
"""Optional fixture to provide a price dict for positions."""
23+
return {
24+
"MGC": 1910.0,
25+
"MNQ": 14950.0,
26+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from project_x_py.position_manager.analytics import AnalyticsMixin
4+
5+
@pytest.mark.asyncio
6+
async def test_calculate_position_pnl_long_short(position_manager, mock_positions_data):
7+
pm = position_manager
8+
# Long position: current_price > average_price
9+
long_pos = [p for p in await pm.get_all_positions() if p.type == 1][0]
10+
pnl_long = pm.calculate_position_pnl(long_pos, current_price=1910.0)
11+
assert pnl_long > 0
12+
13+
# Short position: current_price < average_price
14+
short_pos = [p for p in await pm.get_all_positions() if p.type == 2][0]
15+
pnl_short = pm.calculate_position_pnl(short_pos, current_price=14950.0)
16+
assert pnl_short > 0 # Short: average 15000 > 14950 = profit
17+
18+
@pytest.mark.asyncio
19+
async def test_calculate_position_pnl_with_point_value(position_manager, mock_positions_data):
20+
pm = position_manager
21+
long_pos = [p for p in await pm.get_all_positions() if p.type == 1][0]
22+
# Use point_value scaling
23+
pnl = pm.calculate_position_pnl(long_pos, current_price=1910.0, point_value=2.0)
24+
# Should be double the default
25+
base = pm.calculate_position_pnl(long_pos, current_price=1910.0)
26+
assert abs(pnl - base * 2.0) < 1e-6
27+
28+
@pytest.mark.asyncio
29+
async def test_calculate_portfolio_pnl(position_manager, populate_prices):
30+
pm = position_manager
31+
await pm.get_all_positions()
32+
prices = populate_prices
33+
total_pnl, positions_with_prices = pm.calculate_portfolio_pnl(prices)
34+
# MGC: long, size=1, avg=1900, price=1910 => +10;
35+
# MNQ: short, size=2, avg=15000, price=14950 => (15000-14950)*2=+100
36+
assert abs(total_pnl - 110.0) < 1e-3
37+
assert positions_with_prices == 2
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import pytest
2+
from unittest.mock import AsyncMock, patch
3+
4+
import asyncio
5+
6+
@pytest.mark.asyncio
7+
async def test_get_all_positions_updates_stats(position_manager, mock_positions_data):
8+
pm = position_manager
9+
result = await pm.get_all_positions()
10+
assert len(result) == len(mock_positions_data)
11+
assert pm.stats["positions_tracked"] == len(mock_positions_data)
12+
assert set(pm.tracked_positions.keys()) == {d["contractId"] for d in mock_positions_data}
13+
14+
@pytest.mark.asyncio
15+
async def test_get_position_cache_vs_api(position_manager):
16+
pm = position_manager
17+
18+
# a) Realtime disabled: should call API
19+
pm._realtime_enabled = False
20+
with patch.object(pm.project_x, "search_open_positions", wraps=pm.project_x.search_open_positions) as mock_search:
21+
pos = await pm.get_position("MGC")
22+
assert pos.id
23+
mock_search.assert_called_once()
24+
25+
# b) Realtime enabled: should use cache only
26+
pm._realtime_enabled = True
27+
# Prepopulate cache
28+
mgc_pos = await pm.get_position("MGC")
29+
pm.tracked_positions["MGC"] = mgc_pos
30+
with patch.object(pm.project_x, "search_open_positions", side_effect=Exception("Should not be called")):
31+
pos2 = await pm.get_position("MGC")
32+
assert pos2 is pm.tracked_positions["MGC"]
33+
34+
@pytest.mark.asyncio
35+
async def test_is_position_open(position_manager):
36+
pm = position_manager
37+
await pm.get_all_positions()
38+
assert pm.is_position_open("MGC") is True
39+
assert pm.is_position_open("UNKNOWN") is False
40+
# Simulate closed size
41+
pm.tracked_positions["MGC"].size = 0
42+
assert pm.is_position_open("MGC") is False
43+
44+
@pytest.mark.asyncio
45+
async def test_refresh_positions(position_manager):
46+
pm = position_manager
47+
prev_stats = dict(pm.stats)
48+
changed = await pm.refresh_positions()
49+
assert changed is True
50+
assert pm.stats["positions_tracked"] == len(pm.tracked_positions)
51+
52+
@pytest.mark.asyncio
53+
async def test_cleanup(position_manager):
54+
pm = position_manager
55+
# Prepopulate tracked_positions and position_alerts
56+
await pm.get_all_positions()
57+
pm.position_alerts = {"foo": "bar"}
58+
pm.order_manager = object()
59+
pm._order_sync_enabled = True
60+
61+
await pm.cleanup()
62+
assert pm.tracked_positions == {}
63+
assert pm.position_alerts == {}
64+
assert pm.order_manager is None
65+
assert pm._order_sync_enabled is False
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
@pytest.mark.asyncio
4+
async def test_get_risk_metrics_basic(position_manager, mock_positions_data):
5+
pm = position_manager
6+
await pm.get_all_positions()
7+
metrics = await pm.get_risk_metrics()
8+
9+
# Compute expected total_exposure, num_contracts, diversification_score
10+
expected_total_exposure = sum(abs(d["size"]) for d in mock_positions_data)
11+
expected_num_contracts = len(set(d["contractId"] for d in mock_positions_data))
12+
# Diversification: 0 if only 1 contract, up to 1.0 for max diversity
13+
expected_diversification = (expected_num_contracts - 1) / (expected_num_contracts or 1)
14+
assert abs(metrics["total_exposure"] - expected_total_exposure) < 1e-3
15+
assert metrics["num_contracts"] == expected_num_contracts
16+
assert abs(metrics["diversification_score"] - expected_diversification) < 1e-3
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
from unittest.mock import AsyncMock
3+
4+
@pytest.mark.asyncio
5+
async def test_validate_position_payload_valid_invalid(position_manager, mock_positions_data):
6+
pm = position_manager
7+
valid = pm._validate_position_payload(mock_positions_data[0])
8+
assert valid is True
9+
10+
# Missing required field
11+
invalid = dict(mock_positions_data[0])
12+
invalid.pop("contractId")
13+
assert pm._validate_position_payload(invalid) is False
14+
15+
# Invalid type
16+
invalid2 = dict(mock_positions_data[0])
17+
invalid2["size"] = "not_a_number"
18+
assert pm._validate_position_payload(invalid2) is False
19+
20+
@pytest.mark.asyncio
21+
async def test_process_position_data_open_and_close(position_manager, mock_positions_data):
22+
pm = position_manager
23+
# Patch callback
24+
pm._trigger_callbacks = AsyncMock()
25+
position_data = dict(mock_positions_data[0])
26+
27+
# Open/update
28+
await pm._process_position_data(position_data)
29+
key = position_data["contractId"]
30+
assert key in pm.tracked_positions
31+
32+
# Close
33+
closure_data = dict(position_data)
34+
closure_data["size"] = 0
35+
await pm._process_position_data(closure_data)
36+
assert key not in pm.tracked_positions
37+
assert pm.stats["positions_closed"] == 1
38+
pm._trigger_callbacks.assert_any_call("position_closed", closure_data)

0 commit comments

Comments
 (0)