Skip to content

Commit b9a7667

Browse files
committed
Fix checkpointing
1 parent db1a572 commit b9a7667

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

investing_algorithm_framework/app/app.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import gc
2+
import shutil
23
import inspect
34
import logging
45
import os
@@ -1142,7 +1143,9 @@ def filter_function(
11421143
market=market,
11431144
trading_symbol=trading_symbol,
11441145
use_checkpoints=use_checkpoints,
1145-
backtest_storage_directory=None,
1146+
backtest_storage_directory=(
1147+
backtest_storage_directory
1148+
)
11461149
)
11471150
)
11481151

@@ -1200,6 +1203,9 @@ def load_backtest_filter_fn(bt: Backtest) -> bool:
12001203
backtests_ordered_by_strategy.setdefault(
12011204
backtest.metadata["id"], []
12021205
).append(backtest)
1206+
1207+
# Remove all temp storage directories
1208+
shutil.rmtree(path)
12031209
else:
12041210
# Remove all strategies that are not in the final selection
12051211
backtests_ordered_by_strategy = {
@@ -1214,6 +1220,13 @@ def load_backtest_filter_fn(bt: Backtest) -> bool:
12141220
combine_backtests(backtests_ordered_by_strategy[strategy])
12151221
)
12161222

1223+
if backtest_storage_directory is not None:
1224+
# Save final combined backtests to storage directory
1225+
save_backtests_to_directory(
1226+
backtests=backtests,
1227+
directory_path=backtest_storage_directory,
1228+
)
1229+
12171230
return backtests
12181231

12191232
def run_vector_backtest(
@@ -1339,11 +1352,12 @@ def run_vector_backtest(
13391352
backtest_date_range=backtest_date_range,
13401353
storage_directory=backtest_storage_directory,
13411354
):
1342-
backtest = backtest_service.load_backtest_by_strategy(
1343-
strategy=strategy,
1344-
backtest_date_range=backtest_date_range,
1345-
storage_directory=backtest_storage_directory,
1346-
)
1355+
backtest = backtest_service\
1356+
.load_backtest_by_strategy_and_backtest_date_range(
1357+
strategy=strategy,
1358+
backtest_date_range=backtest_date_range,
1359+
storage_directory=backtest_storage_directory,
1360+
)
13471361
else:
13481362
try:
13491363
run = backtest_service.create_vector_backtest(
@@ -1671,7 +1685,8 @@ def run_permutation_test(
16711685
snapshot_interval=SnapshotInterval.DAILY,
16721686
risk_free_rate=risk_free_rate,
16731687
market=market,
1674-
trading_symbol=trading_symbol
1688+
trading_symbol=trading_symbol,
1689+
use_checkpoints=False
16751690
)
16761691
backtest_metrics = backtest.get_backtest_metrics(backtest_date_range)
16771692

@@ -1753,7 +1768,8 @@ def run_permutation_test(
17531768
risk_free_rate=risk_free_rate,
17541769
skip_data_sources_initialization=True,
17551770
market=market,
1756-
trading_symbol=trading_symbol
1771+
trading_symbol=trading_symbol,
1772+
use_checkpoints=False
17571773
)
17581774

17591775
# Add the results of the permuted backtest to the main backtest

investing_algorithm_framework/domain/backtesting/backtest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,16 @@ def merge(self, other: 'Backtest') -> 'Backtest':
467467

468468
return merged
469469

470+
def get_metadata(self) -> Dict[str, str]:
471+
"""
472+
Get the metadata of the backtest.
473+
474+
Returns:
475+
Dict[str, str]: A dictionary containing the metadata
476+
of the backtest.
477+
"""
478+
return self.metadata
479+
470480
def get_backtest_date_ranges(self):
471481
"""
472482
Get the date ranges for the backtest.

investing_algorithm_framework/services/backtesting/backtest_service.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -672,28 +672,30 @@ def backtest_exists(
672672
backtest_directory = os.path.join(storage_directory, strategy_id)
673673

674674
if os.path.exists(backtest_directory):
675-
backtest = Backtest.load(backtest_directory)
676-
backtest_date_ranges = backtest\
677-
.get_most_granular_ohlcv_data_source()
675+
backtest = Backtest.open(backtest_directory)
676+
backtest_date_ranges = backtest.get_backtest_date_ranges()
678677

679678
for backtest_date_range_ref in backtest_date_ranges:
680679

681680
if backtest_date_range_ref.start_date \
682-
== backtest_date_range_ref.start_date and \
681+
== backtest_date_range.start_date and \
683682
backtest_date_range_ref.end_date \
684-
== backtest_date_range_ref.end_date:
683+
== backtest_date_range.end_date:
685684
return True
686685

687686
return False
688687

689-
def load_backtest_by_strategy(
688+
def load_backtest_by_strategy_and_backtest_date_range(
690689
self,
691690
strategy,
692691
backtest_date_range: BacktestDateRange,
693692
storage_directory: str
694693
) -> Backtest:
695694
"""
696695
Load a backtest for the given strategy and backtest date range.
696+
If the backtest does not exist, an exception will be raised.
697+
For the given backtest, only the run and metrics corresponding
698+
to the backtest date range will be returned.
697699
698700
Args:
699701
strategy: The strategy to load the backtest for.
@@ -709,14 +711,9 @@ def load_backtest_by_strategy(
709711
backtest_directory = os.path.join(storage_directory, strategy_id)
710712

711713
if os.path.exists(backtest_directory):
712-
backtest = Backtest.load(backtest_directory)
713-
run = backtest.get_run(backtest_date_range)
714-
metrics = backtest.get_metrics(backtest_date_range)
715-
return Backtest(
716-
backtest_runs=[run],
717-
backtest_summary=generate_backtest_summary_metrics(
718-
[metrics]
719-
)
720-
)
714+
backtest = Backtest.open(backtest_directory)
715+
run = backtest.get_backtest_run(backtest_date_range)
716+
metadata = backtest.get_metadata()
717+
return Backtest(backtest_runs=[run], metadata=metadata)
721718
else:
722719
raise OperationalException("Backtest does not exist.")

0 commit comments

Comments
 (0)