|
4 | 4 | import sys |
5 | 5 | import time |
6 | 6 | import unittest |
| 7 | +import warnings |
7 | 8 | from concurrent.futures.process import ProcessPoolExecutor |
8 | 9 | from contextlib import contextmanager |
9 | 10 | from glob import glob |
@@ -982,6 +983,50 @@ def test_MultiBacktest(self): |
982 | 983 | print(start_method, time.monotonic() - start_time) |
983 | 984 | plot_heatmaps(heatmap.mean(axis=1), open_browser=False) |
984 | 985 |
|
| 986 | + def test_MultiBacktest_keeps_zero_trade_runs(self): |
| 987 | + datasets = [GOOG[:-4], GOOG[:-3], GOOG[:-2], GOOG[:-1], GOOG] |
| 988 | + cases = { |
| 989 | + 'all_false': ([False, False, False, False, False], [0, 0, 0, 0, 0]), |
| 990 | + 'first_true_rest_false': ([True, False, False, False, False], [1, 0, 0, 0, 0]), |
| 991 | + 'first_false_second_true': ([False, True, False, False, False], [0, 1, 0, 0, 0]), |
| 992 | + } |
| 993 | + |
| 994 | + for name, (will_buys, expected_trades) in cases.items(): |
| 995 | + class TestStrat(Strategy): |
| 996 | + def init(self): |
| 997 | + self.will_buy = will_buys[len(self.data.index) - 2144] |
| 998 | + self.has_bought = False |
| 999 | + |
| 1000 | + def next(self): |
| 1001 | + if not self.will_buy: |
| 1002 | + return |
| 1003 | + if self.position: |
| 1004 | + self.position.close() |
| 1005 | + if not self.has_bought: |
| 1006 | + self.buy() |
| 1007 | + self.has_bought = True |
| 1008 | + |
| 1009 | + with self.subTest(case=name), warnings.catch_warnings(): |
| 1010 | + warnings.filterwarnings( |
| 1011 | + 'ignore', |
| 1012 | + message='If you want to use multi-process optimization', |
| 1013 | + category=RuntimeWarning, |
| 1014 | + ) |
| 1015 | + result = MultiBacktest( |
| 1016 | + datasets, |
| 1017 | + TestStrat, |
| 1018 | + cash=10_000, |
| 1019 | + commission=.002, |
| 1020 | + exclusive_orders=True, |
| 1021 | + ).run() |
| 1022 | + |
| 1023 | + self.assertIsInstance(result, pd.DataFrame) |
| 1024 | + self.assertEqual(result.columns.tolist(), [0, 1, 2, 3, 4]) |
| 1025 | + self.assertIn('# Trades', result.index) |
| 1026 | + self.assertEqual(result.loc['# Trades'].astype(int).tolist(), expected_trades) |
| 1027 | + self.assertFalse(any(isinstance(value, pd.Series) for value in result.to_numpy().ravel())) |
| 1028 | + self.assertIn('Equity Final [$]', result.index) |
| 1029 | + |
985 | 1030 |
|
986 | 1031 | class TestUtil(TestCase): |
987 | 1032 | def test_as_str(self): |
|
0 commit comments