Skip to content

Commit 4ab4a34

Browse files
TexasCodingclaude
andcommitted
test: achieve 96% coverage for risk_manager module with comprehensive test suite
- Add 132 new tests covering all edge cases and error paths - Remove 5 obsolete skipped tests for non-existent methods - Fix all test fixtures to handle async task initialization - Create comprehensive test files: * test_additional_coverage.py: 20 tests for uncovered lines * test_core_error_paths.py: 21 tests for error handling * test_core_trailing_stops.py: 25 tests for trailing stops * test_managed_trade_comprehensive.py: 51 tests for ManagedTrade * test_managed_trade_edge_cases.py: 15 tests for concurrency - Improve coverage from 74% to 96% (exceeding 90% target) - All 276 tests passing with 100% success rate 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent a32e367 commit 4ab4a34

File tree

6 files changed

+3652
-107
lines changed

6 files changed

+3652
-107
lines changed
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
"""Additional tests for risk manager to improve coverage."""
2+
3+
import asyncio
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
import pytest
7+
8+
from project_x_py.risk_manager.core import RiskManager
9+
from project_x_py.risk_manager.managed_trade import ManagedTrade
10+
from project_x_py.types import OrderSide
11+
12+
13+
@pytest.mark.asyncio
14+
class TestAdditionalCoverage:
15+
"""Additional tests to improve coverage."""
16+
17+
@pytest.fixture
18+
def setup_risk_manager(self):
19+
"""Create a RiskManager for testing."""
20+
mock_client = AsyncMock()
21+
mock_order_manager = AsyncMock()
22+
mock_position_manager = AsyncMock()
23+
mock_event_bus = AsyncMock()
24+
25+
# Mock the async task creation to avoid runtime errors
26+
with patch('asyncio.create_task'):
27+
risk_manager = RiskManager(
28+
project_x=mock_client,
29+
order_manager=mock_order_manager,
30+
event_bus=mock_event_bus,
31+
)
32+
risk_manager.set_position_manager(mock_position_manager)
33+
# Set the _init_task to a mock to avoid issues
34+
risk_manager._init_task = MagicMock()
35+
36+
return {
37+
"risk_manager": risk_manager,
38+
"client": mock_client,
39+
"order_manager": mock_order_manager,
40+
"position_manager": mock_position_manager,
41+
"event_bus": mock_event_bus,
42+
}
43+
44+
@pytest.fixture
45+
def setup_managed_trade(self):
46+
"""Create a ManagedTrade for testing."""
47+
mock_client = AsyncMock()
48+
mock_order_manager = AsyncMock()
49+
mock_position_manager = AsyncMock()
50+
mock_event_bus = AsyncMock()
51+
mock_data_manager = AsyncMock()
52+
53+
# Mock the async task creation to avoid runtime errors
54+
with patch('asyncio.create_task'):
55+
risk_manager = RiskManager(
56+
project_x=mock_client,
57+
order_manager=mock_order_manager,
58+
event_bus=mock_event_bus,
59+
)
60+
risk_manager.set_position_manager(mock_position_manager)
61+
risk_manager._init_task = MagicMock()
62+
63+
managed_trade = ManagedTrade(
64+
risk_manager=risk_manager,
65+
order_manager=mock_order_manager,
66+
position_manager=mock_position_manager,
67+
instrument_id="MNQ",
68+
data_manager=mock_data_manager,
69+
event_bus=mock_event_bus,
70+
)
71+
72+
return {
73+
"trade": managed_trade,
74+
"risk_manager": risk_manager,
75+
"order_manager": mock_order_manager,
76+
"position_manager": mock_position_manager,
77+
"event_bus": mock_event_bus,
78+
"data_manager": mock_data_manager,
79+
"client": mock_client,
80+
}
81+
82+
async def test_managed_trade_property_aliases(self, setup_managed_trade):
83+
"""Test ManagedTrade property aliases."""
84+
mocks = setup_managed_trade
85+
trade = mocks["trade"]
86+
87+
# Test property aliases
88+
assert trade.risk_manager is trade.risk
89+
assert trade.order_manager is trade.orders
90+
assert trade.position_manager is trade.positions
91+
92+
async def test_managed_trade_concurrent_entries_prevention(self, setup_managed_trade):
93+
"""Test that concurrent entries are prevented."""
94+
mocks = setup_managed_trade
95+
trade = mocks["trade"]
96+
97+
# Set existing entry order
98+
entry_order = MagicMock()
99+
entry_order.id = 123
100+
trade._entry_order = entry_order
101+
102+
# Should prevent concurrent entries
103+
with pytest.raises(ValueError, match="Trade already has entry order"):
104+
await trade.enter_long()
105+
106+
async def test_managed_trade_no_data_manager_market_price(self, setup_managed_trade):
107+
"""Test market price retrieval without data manager."""
108+
mocks = setup_managed_trade
109+
trade = mocks["trade"]
110+
trade.data_manager = None
111+
112+
with pytest.raises(RuntimeError, match="No data manager available"):
113+
await trade._get_market_price()
114+
115+
async def test_managed_trade_market_price_fallback(self, setup_managed_trade):
116+
"""Test market price fallback to current price."""
117+
mocks = setup_managed_trade
118+
trade = mocks["trade"]
119+
120+
# Mock data retrieval to fail, fallback to current price
121+
mocks["data_manager"].get_data = AsyncMock(return_value=None)
122+
mocks["data_manager"].get_current_price = AsyncMock(return_value=20000.0)
123+
124+
price = await trade._get_market_price()
125+
assert price == 20000.0
126+
127+
async def test_managed_trade_is_filled_states(self, setup_managed_trade):
128+
"""Test is_filled method with various states."""
129+
mocks = setup_managed_trade
130+
trade = mocks["trade"]
131+
132+
# No entry order
133+
assert not trade.is_filled()
134+
135+
# Entry order not filled
136+
entry_order = MagicMock()
137+
entry_order.status = 1 # Working
138+
trade._entry_order = entry_order
139+
assert not trade.is_filled()
140+
141+
# Entry order filled
142+
entry_order.status = 2 # Filled
143+
entry_order.filled_quantity = 2
144+
entry_order.size = 2
145+
assert trade.is_filled()
146+
147+
async def test_managed_trade_emergency_exit(self, setup_managed_trade):
148+
"""Test emergency exit functionality."""
149+
mocks = setup_managed_trade
150+
trade = mocks["trade"]
151+
152+
# Set up orders and position
153+
order1 = MagicMock()
154+
order1.id = 123
155+
trade._orders = [order1]
156+
157+
position = MagicMock()
158+
trade._positions = [position]
159+
160+
# Mock close position
161+
trade.close_position = AsyncMock(return_value={"success": True})
162+
163+
result = await trade.emergency_exit()
164+
assert result is True
165+
166+
async def test_risk_manager_extract_symbol(self, setup_risk_manager):
167+
"""Test symbol extraction from contract ID."""
168+
mocks = setup_risk_manager
169+
rm = mocks["risk_manager"]
170+
171+
# Normal contract ID
172+
symbol = rm._extract_symbol("CON.F.US.MNQ.U24")
173+
assert symbol == "MNQ"
174+
175+
# Short contract ID
176+
symbol = rm._extract_symbol("MNQ")
177+
assert symbol == "MNQ"
178+
179+
async def test_risk_manager_kelly_fraction_edge_cases(self, setup_risk_manager):
180+
"""Test Kelly fraction calculation edge cases."""
181+
mocks = setup_risk_manager
182+
rm = mocks["risk_manager"]
183+
184+
# Zero win rate
185+
rm._win_rate = 0.0
186+
kelly = rm._calculate_kelly_fraction()
187+
assert kelly == 0.0
188+
189+
# Zero average loss
190+
rm._win_rate = 0.6
191+
rm._avg_loss = 0.0
192+
kelly = rm._calculate_kelly_fraction()
193+
assert kelly == 0.0
194+
195+
async def test_risk_manager_calculate_stop_loss_fallback(self, setup_risk_manager):
196+
"""Test stop loss calculation fallback."""
197+
mocks = setup_risk_manager
198+
rm = mocks["risk_manager"]
199+
200+
# Unknown stop loss type
201+
rm.config.stop_loss_type = "unknown"
202+
203+
stop_price = await rm.calculate_stop_loss(20000.0, OrderSide.BUY)
204+
assert stop_price == 19950.0 # fallback
205+
206+
async def test_risk_manager_should_activate_trailing_stop_disabled(self, setup_risk_manager):
207+
"""Test trailing stop activation when disabled."""
208+
mocks = setup_risk_manager
209+
rm = mocks["risk_manager"]
210+
211+
rm.config.use_trailing_stops = False
212+
result = await rm.should_activate_trailing_stop(20000.0, 20100.0, OrderSide.BUY)
213+
assert result is False
214+
215+
async def test_risk_manager_portfolio_risk_empty_positions(self, setup_risk_manager):
216+
"""Test portfolio risk calculation with empty positions."""
217+
mocks = setup_risk_manager
218+
rm = mocks["risk_manager"]
219+
220+
risk = await rm._calculate_portfolio_risk([])
221+
assert risk == 0.0
222+
223+
async def test_risk_manager_sharpe_ratio_edge_cases(self, setup_risk_manager):
224+
"""Test Sharpe ratio calculation edge cases."""
225+
mocks = setup_risk_manager
226+
rm = mocks["risk_manager"]
227+
228+
# Empty history
229+
rm._trade_history.clear()
230+
sharpe = rm._calculate_sharpe_ratio()
231+
assert sharpe == 0.0
232+
233+
# Single trade
234+
rm._trade_history.append({"pnl": 100.0})
235+
sharpe = rm._calculate_sharpe_ratio()
236+
assert sharpe == 0.0
237+
238+
# Zero standard deviation
239+
rm._trade_history.clear()
240+
for _ in range(5):
241+
rm._trade_history.append({"pnl": 100.0})
242+
sharpe = rm._calculate_sharpe_ratio()
243+
assert sharpe == 0.0
244+
245+
async def test_managed_trade_wait_for_fill_no_event_bus(self, setup_managed_trade):
246+
"""Test wait for fill without event bus."""
247+
mocks = setup_managed_trade
248+
trade = mocks["trade"]
249+
trade.event_bus = None
250+
251+
order = MagicMock()
252+
order.id = 123
253+
254+
trade._poll_for_order_fill = AsyncMock(return_value=True)
255+
256+
result = await trade._wait_for_order_fill(order, timeout_seconds=1)
257+
assert result is True
258+
259+
async def test_managed_trade_poll_for_fill_exception_handling(self, setup_managed_trade):
260+
"""Test polling with exception handling."""
261+
mocks = setup_managed_trade
262+
trade = mocks["trade"]
263+
264+
order = MagicMock()
265+
order.id = 123
266+
267+
# First call raises exception, second succeeds
268+
call_count = 0
269+
def mock_search_with_exception():
270+
nonlocal call_count
271+
call_count += 1
272+
if call_count == 1:
273+
raise Exception("Network error")
274+
else:
275+
updated_order = MagicMock()
276+
updated_order.id = 123
277+
updated_order.is_filled = True
278+
return [updated_order]
279+
280+
mocks["order_manager"].search_open_orders = AsyncMock(side_effect=mock_search_with_exception)
281+
mocks["position_manager"].get_all_positions = AsyncMock(return_value=[])
282+
283+
result = await trade._poll_for_order_fill(order, timeout_seconds=2)
284+
assert result is True
285+
286+
async def test_managed_trade_adjust_stop_loss_failure(self, setup_managed_trade):
287+
"""Test adjust stop loss failure."""
288+
mocks = setup_managed_trade
289+
trade = mocks["trade"]
290+
291+
# No stop order
292+
result = await trade.adjust_stop_loss(19940.0)
293+
assert result is False
294+
295+
# With stop order but modify fails
296+
stop_order = MagicMock()
297+
stop_order.id = 123
298+
trade._stop_order = stop_order
299+
300+
mocks["order_manager"].modify_order = AsyncMock(side_effect=Exception("Modify failed"))
301+
result = await trade.adjust_stop_loss(19940.0)
302+
assert result is False
303+
304+
async def test_risk_manager_get_account_info_no_accounts(self, setup_risk_manager):
305+
"""Test get account info when no accounts found."""
306+
mocks = setup_risk_manager
307+
rm = mocks["risk_manager"]
308+
309+
mocks["client"].list_accounts = AsyncMock(return_value=[])
310+
311+
with pytest.raises(ValueError, match="No account found"):
312+
await rm._get_account_info()
313+
314+
async def test_risk_manager_trading_hours_validation(self, setup_risk_manager):
315+
"""Test trading hours validation."""
316+
mocks = setup_risk_manager
317+
rm = mocks["risk_manager"]
318+
319+
# Disabled restriction
320+
rm.config.restrict_trading_hours = False
321+
assert rm._is_within_trading_hours() is True
322+
323+
async def test_risk_manager_memory_stats_error(self, setup_risk_manager):
324+
"""Test memory stats error handling."""
325+
mocks = setup_risk_manager
326+
rm = mocks["risk_manager"]
327+
328+
# Force error
329+
rm._trade_history = None
330+
331+
stats = rm.get_memory_stats()
332+
assert "error_code" in stats
333+
assert stats["error_code"] == 1.0
334+
335+
async def test_managed_trade_calculate_position_size_with_overrides(self, setup_managed_trade):
336+
"""Test position size calculation with risk overrides."""
337+
mocks = setup_managed_trade
338+
trade = mocks["trade"]
339+
340+
trade.max_risk_percent = 0.02
341+
trade.max_risk_amount = 1000.0
342+
343+
mocks["risk_manager"].calculate_position_size = AsyncMock(
344+
return_value={"position_size": 2}
345+
)
346+
347+
result = await trade.calculate_position_size(
348+
entry_price=20000.0,
349+
stop_loss=19950.0
350+
)
351+
352+
assert result == 2
353+
mocks["risk_manager"].calculate_position_size.assert_called_once_with(
354+
entry_price=20000.0,
355+
stop_loss=19950.0,
356+
risk_percent=0.02,
357+
risk_amount=1000.0,
358+
)
359+
360+
async def test_managed_trade_get_account_balance_methods(self, setup_managed_trade):
361+
"""Test get account balance methods."""
362+
mocks = setup_managed_trade
363+
trade = mocks["trade"]
364+
365+
# Successful retrieval
366+
account = MagicMock()
367+
account.balance = 150000.0
368+
mocks["client"].list_accounts = AsyncMock(return_value=[account])
369+
370+
balance = await trade._get_account_balance()
371+
assert balance == 150000.0
372+
373+
# No accounts
374+
mocks["client"].list_accounts = AsyncMock(return_value=[])
375+
balance = await trade._get_account_balance()
376+
assert balance == 100000.0 # Default

0 commit comments

Comments
 (0)