Skip to content

Commit 09afc32

Browse files
committed
Add checkpoint tests
1 parent b9a7667 commit 09afc32

File tree

1 file changed

+327
-0
lines changed

1 file changed

+327
-0
lines changed
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
import os
2+
import time
3+
from itertools import product
4+
import pandas as pd
5+
from datetime import datetime, timedelta, timezone
6+
from unittest import TestCase
7+
from typing import Dict, Any
8+
9+
from pyindicators import ema, rsi, crossover, crossunder
10+
11+
from investing_algorithm_framework import TradingStrategy, DataSource, \
12+
TimeUnit, DataType, create_app, BacktestDateRange, PositionSize, \
13+
TradeStatus, RESOURCE_DIRECTORY, SnapshotInterval, generate_strategy_id
14+
15+
16+
class RSIEMACrossoverStrategy(TradingStrategy):
17+
time_unit = TimeUnit.HOUR
18+
interval = 2
19+
20+
def __init__(
21+
self,
22+
id,
23+
symbols,
24+
position_sizes,
25+
time_unit: TimeUnit,
26+
interval: int,
27+
market: str,
28+
rsi_time_frame: str,
29+
rsi_period: int,
30+
rsi_overbought_threshold,
31+
rsi_oversold_threshold,
32+
ema_time_frame,
33+
ema_short_period,
34+
ema_long_period,
35+
ema_cross_lookback_window: int = 10
36+
):
37+
self.rsi_time_frame = rsi_time_frame
38+
self.rsi_period = rsi_period
39+
self.rsi_result_column = f"rsi_{self.rsi_period}"
40+
self.rsi_overbought_threshold = rsi_overbought_threshold
41+
self.rsi_oversold_threshold = rsi_oversold_threshold
42+
self.ema_time_frame = ema_time_frame
43+
self.ema_short_result_column = f"ema_{ema_short_period}"
44+
self.ema_long_result_column = f"ema_{ema_long_period}"
45+
self.ema_crossunder_result_column = "ema_crossunder"
46+
self.ema_crossover_result_column = "ema_crossover"
47+
self.ema_short_period = ema_short_period
48+
self.ema_long_period = ema_long_period
49+
self.ema_cross_lookback_window = ema_cross_lookback_window
50+
data_sources = []
51+
52+
super().__init__(
53+
id=id,
54+
data_sources=data_sources,
55+
time_unit=time_unit,
56+
interval=interval,
57+
symbols=symbols,
58+
position_sizes=position_sizes
59+
)
60+
61+
for symbol in self.symbols:
62+
full_symbol = f"{symbol}/EUR"
63+
data_sources.append(
64+
DataSource(
65+
identifier=f"{symbol}_rsi_data",
66+
data_type=DataType.OHLCV,
67+
time_frame=self.rsi_time_frame,
68+
market=market,
69+
symbol=full_symbol,
70+
pandas=True
71+
)
72+
)
73+
data_sources.append(
74+
DataSource(
75+
identifier=f"{symbol}_ema_data",
76+
data_type=DataType.OHLCV,
77+
time_frame=self.ema_time_frame,
78+
market=market,
79+
symbol=full_symbol,
80+
pandas=True
81+
)
82+
)
83+
84+
def prepare_indicators(
85+
self,
86+
rsi_data,
87+
ema_data
88+
):
89+
ema_data = ema(
90+
ema_data,
91+
period=self.ema_short_period,
92+
source_column="Close",
93+
result_column=self.ema_short_result_column
94+
)
95+
ema_data = ema(
96+
ema_data,
97+
period=self.ema_long_period,
98+
source_column="Close",
99+
result_column=self.ema_long_result_column
100+
)
101+
# Detect crossover (short EMA crosses above long EMA)
102+
ema_data = crossover(
103+
ema_data,
104+
first_column=self.ema_short_result_column,
105+
second_column=self.ema_long_result_column,
106+
result_column=self.ema_crossover_result_column
107+
)
108+
# Detect crossunder (short EMA crosses below long EMA)
109+
ema_data = crossunder(
110+
ema_data,
111+
first_column=self.ema_short_result_column,
112+
second_column=self.ema_long_result_column,
113+
result_column=self.ema_crossunder_result_column
114+
)
115+
rsi_data = rsi(
116+
rsi_data,
117+
period=self.rsi_period,
118+
source_column="Close",
119+
result_column=self.rsi_result_column
120+
)
121+
122+
return ema_data, rsi_data
123+
124+
def generate_buy_signals(self, data: Dict[str, Any]) -> Dict[str, pd.Series]:
125+
"""
126+
Generate buy signals based on the moving average crossover.
127+
128+
data (Dict[str, Any]): Dictionary containing all the data for
129+
the strategy data sources.
130+
131+
Returns:
132+
Dict[str, pd.Series]: A dictionary where keys are symbols and values
133+
are pandas Series indicating buy signals (True/False).
134+
"""
135+
136+
signals = {}
137+
for symbol in self.symbols:
138+
ema_data_identifier = f"{symbol}_ema_data"
139+
rsi_data_identifier = f"{symbol}_rsi_data"
140+
ema_data, rsi_data = self.prepare_indicators(
141+
data[ema_data_identifier].copy(),
142+
data[rsi_data_identifier].copy()
143+
)
144+
145+
# crossover confirmed
146+
ema_crossover_lookback = ema_data[
147+
self.ema_crossover_result_column].rolling(
148+
window=self.ema_cross_lookback_window
149+
).max().astype(bool)
150+
151+
# use only RSI column
152+
rsi_oversold = rsi_data[self.rsi_result_column] \
153+
< self.rsi_oversold_threshold
154+
155+
# Combine both conditions
156+
buy_signal = rsi_oversold & ema_crossover_lookback
157+
buy_signals = buy_signal.fillna(False).astype(bool)
158+
signals[symbol] = buy_signals
159+
return signals
160+
161+
def generate_sell_signals(self, data: Dict[str, Any]) -> Dict[str, pd.Series]:
162+
"""
163+
Generate sell signals based on the moving average crossover.
164+
165+
Args:
166+
data (Dict[str, Any]): Dictionary containing all the data for
167+
the strategy data sources.
168+
169+
Returns:
170+
Dict[str, pd.Series]: A dictionary where keys are symbols and values
171+
are pandas Series indicating sell signals (True/False).
172+
"""
173+
174+
signals = {}
175+
for symbol in self.symbols:
176+
ema_data_identifier = f"{symbol}_ema_data"
177+
rsi_data_identifier = f"{symbol}_rsi_data"
178+
ema_data, rsi_data = self.prepare_indicators(
179+
data[ema_data_identifier].copy(),
180+
data[rsi_data_identifier].copy()
181+
)
182+
183+
# Confirmed by crossover between short-term EMA and long-term EMA
184+
# within a given lookback window
185+
ema_crossunder_lookback = ema_data[
186+
self.ema_crossunder_result_column].rolling(
187+
window=self.ema_cross_lookback_window
188+
).max().astype(bool)
189+
190+
# use only RSI column
191+
rsi_overbought = rsi_data[self.rsi_result_column] \
192+
>= self.rsi_overbought_threshold
193+
194+
# Combine both conditions
195+
sell_signal = rsi_overbought & ema_crossunder_lookback
196+
sell_signal = sell_signal.fillna(False).astype(bool)
197+
signals[symbol] = sell_signal
198+
return signals
199+
200+
class Test(TestCase):
201+
202+
@staticmethod
203+
def filter_function_with_closed_trades(
204+
backtests, backtest_date_range: BacktestDateRange
205+
):
206+
"""
207+
Filter function that only keeps backtests with at least one closed trade.
208+
"""
209+
filtered = []
210+
for backtest in backtests:
211+
metrics = backtest.get_backtest_metrics(backtest_date_range)
212+
if metrics.number_of_trades_closed > 0:
213+
filtered.append(backtest)
214+
215+
return filtered
216+
217+
def test_run_with_filter_function(self):
218+
"""
219+
Test run_vector_backtests with a filter_function that filters
220+
strategies based on whether they have closed trades.
221+
"""
222+
param_grid = {
223+
"rsi_time_frame": ["2h"],
224+
"rsi_period": [14],
225+
"rsi_overbought_threshold": [70, 80],
226+
"rsi_oversold_threshold": [30, 20],
227+
"ema_time_frame": ["2h"],
228+
"ema_short_period": [50, 100],
229+
"ema_long_period": [150, 200],
230+
"ema_cross_lookback_window": [2, 4, 6, 12]
231+
}
232+
233+
param_options = param_grid
234+
param_variations = [
235+
dict(zip(param_options.keys(), values))
236+
for values in product(*param_options.values())
237+
]
238+
print(
239+
f"Total parameter combinations to evaluate: {len(param_variations)}")
240+
241+
# RESOURCE_DIRECTORY should always point to the parent directory/resources
242+
# Resource directory should point to /tests/resources
243+
# Resource directory is two levels up from the current file
244+
resource_directory = os.path.join(
245+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'resources'
246+
)
247+
config = {RESOURCE_DIRECTORY: resource_directory}
248+
app = create_app(name="GoldenCrossStrategy", config=config)
249+
app.add_market(market="BITVAVO", trading_symbol="EUR", initial_balance=400)
250+
end_date = datetime(2025, 12, 2, tzinfo=timezone.utc)
251+
start_date = end_date - timedelta(days=1095)
252+
253+
# Split into multiple date ranges to test progressive filtering
254+
mid_date = start_date + timedelta(days=365)
255+
date_range_1 = BacktestDateRange(
256+
start_date=start_date, end_date=end_date, name="Period 1"
257+
)
258+
date_range_2 = BacktestDateRange(
259+
start_date=mid_date, end_date=end_date, name="Period 2"
260+
)
261+
strategies = []
262+
for param_set in param_variations:
263+
strategies.append(
264+
RSIEMACrossoverStrategy(
265+
id=generate_strategy_id(param_set),
266+
time_unit=TimeUnit.HOUR,
267+
interval=2,
268+
market="BITVAVO",
269+
rsi_time_frame=param_set["rsi_time_frame"],
270+
rsi_period=param_set["rsi_period"],
271+
rsi_overbought_threshold=param_set[
272+
"rsi_overbought_threshold"
273+
],
274+
rsi_oversold_threshold=param_set[
275+
"rsi_oversold_threshold"
276+
],
277+
ema_time_frame=param_set["ema_time_frame"],
278+
ema_short_period=param_set["ema_short_period"],
279+
ema_long_period=param_set["ema_long_period"],
280+
ema_cross_lookback_window=param_set[
281+
"ema_cross_lookback_window"
282+
],
283+
symbols=[
284+
"BTC",
285+
"ETH"
286+
],
287+
position_sizes=[
288+
PositionSize(
289+
symbol="BTC", percentage_of_portfolio=20.0
290+
),
291+
PositionSize(
292+
symbol="ETH", percentage_of_portfolio=20.0
293+
)
294+
]
295+
)
296+
)
297+
298+
start_time = time.time()
299+
backtests = app.run_vector_backtests(
300+
initial_amount=1000,
301+
backtest_date_ranges=[date_range_1, date_range_2],
302+
strategies=strategies,
303+
snapshot_interval=SnapshotInterval.DAILY,
304+
risk_free_rate=0.027,
305+
trading_symbol="EUR",
306+
market="BITVAVO",
307+
# filter_function=self.filter_function_with_closed_trades,
308+
backtest_storage_directory=os.path.join(
309+
resource_directory, "backtest_reports_for_testing"
310+
),
311+
use_checkpoints=True,
312+
)
313+
end_time = time.time()
314+
duration = end_time - start_time
315+
316+
# Duration must be less than 300 seconds (5 minutes)
317+
# Each backtest should have atleast 2 backtest runs (one for each date range)
318+
for backtest in backtests:
319+
self.assertGreaterEqual(
320+
len(backtest.get_all_backtest_runs()), 2,
321+
"Each backtest should have at least 2 backtest runs"
322+
)
323+
324+
# # Should have fewer backtests than strategies if filter worked
325+
# self.assertLessEqual(len(backtests), len(strategies))
326+
327+

0 commit comments

Comments
 (0)