diff --git a/README.md b/README.md index d07b8680..d79647ed 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,19 @@ bt = Backtest(GOOG, SmaCross, commission=.002, exclusive_orders=True) stats = bt.run() bt.plot() + +# Multi-asset (same API, pass symbol->DataFrame mapping) +assets = { + 'GOOG': GOOG, + 'GOOG_HALF': GOOG.assign(Open=GOOG.Open * .5, + High=GOOG.High * .5, + Low=GOOG.Low * .5, + Close=GOOG.Close * .5) +} +btm = Backtest(assets, SmaCross, commission=.002, + exclusive_orders=True) +stats_m = btm.run() +btm.plot() ``` Results in: diff --git a/backtesting/_plotting.py b/backtesting/_plotting.py index 338454da..3988e734 100644 --- a/backtesting/_plotting.py +++ b/backtesting/_plotting.py @@ -13,10 +13,7 @@ import pandas as pd from bokeh.colors import RGB -from bokeh.colors.named import ( - lime as BULL_COLOR, - tomato as BEAR_COLOR -) +from bokeh.colors.named import lime as BULL_COLOR, tomato as BEAR_COLOR from bokeh.events import DocumentReady from bokeh.plotting import figure as _figure from bokeh.models import ( # type: ignore @@ -24,7 +21,8 @@ CustomJS, ColumnDataSource, CustomJSTransform, - Label, NumeralTickFormatter, + Label, + NumeralTickFormatter, Span, HoverTool, Range1d, @@ -32,6 +30,7 @@ WheelZoomTool, LinearColorMapper, ) + try: from bokeh.models import CustomJSTickFormatter except ImportError: # Bokeh < 3.0 @@ -44,19 +43,19 @@ from backtesting._util import _data_period, _as_list, _Indicator, try_ -with open(os.path.join(os.path.dirname(__file__), 'autoscale_cb.js'), - encoding='utf-8') as _f: +with open(os.path.join(os.path.dirname(__file__), "autoscale_cb.js"), encoding="utf-8") as _f: _AUTOSCALE_JS_CALLBACK = _f.read() -IS_JUPYTER_NOTEBOOK = ('JPY_PARENT_PID' in os.environ or - 'inline' in os.environ.get('MPLBACKEND', '')) +IS_JUPYTER_NOTEBOOK = "JPY_PARENT_PID" in os.environ or "inline" in os.environ.get("MPLBACKEND", "") if IS_JUPYTER_NOTEBOOK: - warnings.warn('Jupyter Notebook detected. ' - 'Setting Bokeh output to notebook. 
' - 'This may not work in Jupyter clients without JavaScript ' - 'support, such as old IDEs. ' - 'Reset with `backtesting.set_bokeh_output(notebook=False)`.') + warnings.warn( + "Jupyter Notebook detected. " + "Setting Bokeh output to notebook. " + "This may not work in Jupyter clients without JavaScript " + "support, such as old IDEs. " + "Reset with `backtesting.set_bokeh_output(notebook=False)`." + ) output_notebook() @@ -71,16 +70,16 @@ def set_bokeh_output(notebook=False): def _windos_safe_filename(filename): - if sys.platform.startswith('win'): - return re.sub(r'[^a-zA-Z0-9,_-]', '_', filename.replace('=', '-')) + if sys.platform.startswith("win"): + return re.sub(r"[^a-zA-Z0-9,_-]", "_", filename.replace("=", "-")) return filename def _bokeh_reset(filename=None): curstate().reset() if filename: - if not filename.endswith('.html'): - filename += '.html' + if not filename.endswith(".html"): + filename += ".html" output_file(filename, title=filename) elif IS_JUPYTER_NOTEBOOK: curstate().output_notebook() @@ -88,22 +87,33 @@ def _bokeh_reset(filename=None): def _add_popcon(): - curdoc().js_on_event(DocumentReady, CustomJS(code='''(function() { var i = document.createElement('iframe'); i.style.display='none';i.width=i.height=1;i.loading='eager';i.src='https://kernc.github.io/backtesting.py/plx.gif.html?utm_source='+location.origin;document.body.appendChild(i);})();''')) # noqa: E501 + curdoc().js_on_event( + DocumentReady, + CustomJS( + code="""(function() { var i = document.createElement('iframe'); i.style.display='none';i.width=i.height=1;i.loading='eager';i.src='https://kernc.github.io/backtesting.py/plx.gif.html?utm_source='+location.origin;document.body.appendChild(i);})();""" + ), + ) # noqa: E501 def _watermark(fig: _figure): fig.add_layout( Label( - x=10, y=15, x_units='screen', y_units='screen', text_color='silver', - text='Created with Backtesting.py: http://kernc.github.io/backtesting.py', - text_alpha=.09)) + x=10, + y=15, + x_units="screen", + 
y_units="screen", + text_color="silver", + text="Created with Backtesting.py: http://kernc.github.io/backtesting.py", + text_alpha=0.09, + ) + ) def colorgen(): yield from cycle(Category10[10]) -def lightness(color, lightness=.94): +def lightness(color, lightness=0.94): rgb = np.array([color.r, color.g, color.b]) / 255 h, _, s = rgb_to_hls(*rgb) rgb = (np.array(hls_to_rgb(h, lightness, s)) * 255).astype(int) @@ -121,82 +131,123 @@ def _maybe_resample_data(resample_rule, df, indicators, equity_data, trades): if resample_rule is False or len(df) <= _MAX_CANDLES: return df, indicators, equity_data, trades - freq_minutes = pd.Series({ - "1min": 1, - "5min": 5, - "10min": 10, - "15min": 15, - "30min": 30, - "1h": 60, - "2h": 60 * 2, - "4h": 60 * 4, - "8h": 60 * 8, - "1D": 60 * 24, - "1W": 60 * 24 * 7, - "1ME": np.inf, - }) + freq_minutes = pd.Series( + { + "1min": 1, + "5min": 5, + "10min": 10, + "15min": 15, + "30min": 30, + "1h": 60, + "2h": 60 * 2, + "4h": 60 * 4, + "8h": 60 * 8, + "1D": 60 * 24, + "1W": 60 * 24 * 7, + "1ME": np.inf, + } + ) timespan = df.index[-1] - df.index[0] require_minutes = (timespan / _MAX_CANDLES).total_seconds() // 60 freq = freq_minutes.where(freq_minutes >= require_minutes).first_valid_index() - warnings.warn(f"Data contains too many candlesticks to plot; downsampling to {freq!r}. " - "See `Backtest.plot(resample=...)`") + warnings.warn( + f"Data contains too many candlesticks to plot; downsampling to {freq!r}. 
" + "See `Backtest.plot(resample=...)`" + ) from .lib import OHLCV_AGG, TRADES_AGG, _EQUITY_AGG - df = df.resample(freq, label='right').agg(OHLCV_AGG).dropna() + + df = df.resample(freq, label="right").agg(OHLCV_AGG).dropna() def try_mean_first(indicator): nonlocal freq - resampled = indicator.df.fillna(np.nan).resample(freq, label='right') + resampled = indicator.df.fillna(np.nan).resample(freq, label="right") try: return resampled.mean() except Exception: return resampled.first() - indicators = [_Indicator(try_mean_first(i).dropna().reindex(df.index).values.T, - **dict(i._opts, name=i.name, - # Replace saved index with the resampled one - index=df.index)) - for i in indicators] + indicators = [ + _Indicator( + try_mean_first(i).dropna().reindex(df.index).values.T, + **dict( + i._opts, + name=i.name, + # Replace saved index with the resampled one + index=df.index, + ), + ) + for i in indicators + ] assert not indicators or indicators[0].df.index.equals(df.index) - equity_data = equity_data.resample(freq, label='right').agg(_EQUITY_AGG).dropna(how='all') + equity_data = equity_data.resample(freq, label="right").agg(_EQUITY_AGG).dropna(how="all") assert equity_data.index.equals(df.index) def _weighted_returns(s, trades=trades): df = trades.loc[s.index] - return ((df['Size'].abs() * df['ReturnPct']) / df['Size'].abs().sum()).sum() + return ((df["Size"].abs() * df["ReturnPct"]) / df["Size"].abs().sum()).sum() def _group_trades(column): def f(s, new_index=pd.Index(df.index.astype(np.int64)), bars=trades[column]): if s.size: # Via int64 because on pandas recently broken datetime mean_time = int(bars.loc[s.index].astype(np.int64).mean()) - new_bar_idx = new_index.get_indexer([mean_time], method='nearest')[0] + new_bar_idx = new_index.get_indexer([mean_time], method="nearest")[0] return new_bar_idx + return f if len(trades): # Avoid pandas "resampling on Int64 index" error - trades = trades.assign(count=1).resample(freq, on='ExitTime', label='right').agg(dict( + 
trades_agg = dict( TRADES_AGG, ReturnPct=_weighted_returns, - count='sum', - EntryBar=_group_trades('EntryTime'), - ExitBar=_group_trades('ExitTime'), - )).dropna() + count="sum", + EntryBar=_group_trades("EntryTime"), + ExitBar=_group_trades("ExitTime"), + ) + if "Symbol" in trades.columns: + trades = ( + trades.assign(count=1) + .groupby("Symbol", group_keys=True) + .resample(freq, on="ExitTime", label="right") + .agg(trades_agg) + .dropna() + .reset_index(level=0) + ) + else: + trades = ( + trades.assign(count=1) + .resample(freq, on="ExitTime", label="right") + .agg(trades_agg) + .dropna() + ) return df, indicators, equity_data, trades -def plot(*, results: pd.Series, - df: pd.DataFrame, - indicators: List[_Indicator], - filename='', plot_width=None, - plot_equity=True, plot_return=False, plot_pl=True, - plot_volume=True, plot_drawdown=False, plot_trades=True, - smooth_equity=False, relative_equity=True, - superimpose=True, resample=True, - reverse_indicators=True, - show_legend=True, open_browser=True): +def plot( + *, + results: pd.Series, + df: pd.DataFrame, + assets: dict[str, pd.DataFrame] | None = None, + indicators: List[_Indicator], + filename="", + plot_width=None, + plot_equity=True, + plot_return=False, + plot_pl=True, + plot_volume=True, + plot_drawdown=False, + plot_trades=True, + smooth_equity=False, + relative_equity=True, + superimpose=True, + resample=True, + reverse_indicators=True, + show_legend=True, + open_browser=True, +): """ Like much of GUI code everywhere, this is a mess. 
""" @@ -208,11 +259,13 @@ def plot(*, results: pd.Series, _bokeh_reset(filename) COLORS = [BEAR_COLOR, BULL_COLOR] - BAR_WIDTH = .8 + BAR_WIDTH = 0.8 - assert df.index.equals(results['_equity_curve'].index) - equity_data = results['_equity_curve'].copy(deep=False) - trades = results['_trades'] + assert df.index.equals(results["_equity_curve"].index) + equity_data = results["_equity_curve"].copy(deep=False) + trades = results["_trades"] + is_multi_asset = bool(assets and len(assets) > 1) + asset_symbols = tuple(assets.keys()) if assets else () plot_volume = plot_volume and not df.Volume.isnull().all() plot_equity = plot_equity and not trades.empty @@ -222,85 +275,113 @@ def plot(*, results: pd.Series, is_datetime_index = isinstance(df.index, pd.DatetimeIndex) from .lib import OHLCV_AGG + # ohlc df may contain many columns. We're only interested in, and pass on to Bokeh, these df = df[list(OHLCV_AGG.keys())].copy(deep=False) # Limit data to max_candles if is_datetime_index: df, indicators, equity_data, trades = _maybe_resample_data( - resample, df, indicators, equity_data, trades) + resample, df, indicators, equity_data, trades + ) df.index.name = None # Provides source name @index - df['datetime'] = df.index # Save original, maybe datetime index + df["datetime"] = df.index # Save original, maybe datetime index df = df.reset_index(drop=True) equity_data = equity_data.reset_index(drop=True) index = df.index new_bokeh_figure = partial( # type: ignore[call-arg] _figure, - x_axis_type='linear', + x_axis_type="linear", width=plot_width, height=400, # TODO: xwheel_pan on horizontal after https://github.com/bokeh/bokeh/issues/14363 tools="xpan,xwheel_zoom,xwheel_pan,box_zoom,undo,redo,reset,save", - active_drag='xpan', - active_scroll='xwheel_zoom') + active_drag="xpan", + active_scroll="xwheel_zoom", + ) pad = (index[-1] - index[0]) / 20 - _kwargs = dict(x_range=Range1d(index[0], index[-1], # type: ignore[call-arg] - min_interval=10, - bounds=(index[0] - pad, - index[-1] + 
pad))) if index.size > 1 else {} + _kwargs = ( + dict( + x_range=Range1d( + index[0], + index[-1], # type: ignore[call-arg] + min_interval=10, + bounds=(index[0] - pad, index[-1] + pad), + ) + ) + if index.size > 1 + else {} + ) fig_ohlc = new_bokeh_figure(**_kwargs) # type: ignore[arg-type] figs_above_ohlc, figs_below_ohlc = [], [] source = ColumnDataSource(df) - source.add((df.Close >= df.Open).values.astype(np.uint8).astype(str), 'inc') - - trade_source = ColumnDataSource(dict( - index=trades['ExitBar'], - datetime=trades['ExitTime'], - size=trades['Size'], - returns_positive=(trades['ReturnPct'] > 0).astype(int).astype(str), - )) + source.add((df.Close >= df.Open).values.astype(np.uint8).astype(str), "inc") + primary_symbol = asset_symbols[0] if is_multi_asset else "_" + asset_fig_by_symbol = {primary_symbol: fig_ohlc} + asset_source_by_symbol = {primary_symbol: source} + asset_extra_figs_by_symbol = {primary_symbol: []} + + trade_source = ColumnDataSource( + dict( + index=trades["ExitBar"], + datetime=trades["ExitTime"], + size=trades["Size"], + returns_positive=(trades["ReturnPct"] > 0).astype(int).astype(str), + symbol=trades["Symbol"] if "Symbol" in trades else np.repeat("Asset", len(trades)), + ) + ) - inc_cmap = factor_cmap('inc', COLORS, ['0', '1']) - cmap = factor_cmap('returns_positive', COLORS, ['0', '1']) - colors_darker = [lightness(BEAR_COLOR, .35), - lightness(BULL_COLOR, .35)] - trades_cmap = factor_cmap('returns_positive', colors_darker, ['0', '1']) + inc_cmap = factor_cmap("inc", COLORS, ["0", "1"]) + cmap = factor_cmap("returns_positive", COLORS, ["0", "1"]) + colors_darker = [lightness(BEAR_COLOR, 0.35), lightness(BULL_COLOR, 0.35)] + trades_cmap = factor_cmap("returns_positive", colors_darker, ["0", "1"]) if is_datetime_index: fig_ohlc.xaxis.formatter = CustomJSTickFormatter( # type: ignore[attr-defined] - args=dict(axis=fig_ohlc.xaxis[0], - formatter=DatetimeTickFormatter(days='%a, %d %b', - months='%m/%Y'), - source=source), - code=''' + 
args=dict( + axis=fig_ohlc.xaxis[0], + formatter=DatetimeTickFormatter(days="%a, %d %b", months="%m/%Y"), + source=source, + ), + code=""" this.labels = this.labels || formatter.doFormat(ticks .map(i => source.data.datetime[i]) .filter(t => t !== undefined)); return this.labels[index] || ""; - ''') + """, + ) - NBSP = '\N{NBSP}' * 4 # noqa: E999 - ohlc_extreme_values = df[['High', 'Low']].copy(deep=False) + NBSP = "\N{NBSP}" * 4 # noqa: E999 + ohlc_extreme_values = df[["High", "Low"]].copy(deep=False) ohlc_tooltips = [ - ('x, y', NBSP.join(('$index', - '$y{0,0.0[0000]}'))), - ('OHLC', NBSP.join(('@Open{0,0.0[0000]}', - '@High{0,0.0[0000]}', - '@Low{0,0.0[0000]}', - '@Close{0,0.0[0000]}'))), - ('Volume', '@Volume{0,0}')] + ("x, y", NBSP.join(("$index", "$y{0,0.0[0000]}"))), + ( + "OHLC", + NBSP.join( + ( + "@Open{0,0.0[0000]}", + "@High{0,0.0[0000]}", + "@Low{0,0.0[0000]}", + "@Close{0,0.0[0000]}", + ) + ), + ), + ("Volume", "@Volume{0,0}"), + ] + asset_ohlc_extreme_by_symbol = {primary_symbol: ohlc_extreme_values} + asset_ohlc_tooltips_by_symbol = {primary_symbol: list(ohlc_tooltips)} + asset_ohlc_bars_by_symbol = {} def new_indicator_figure(**kwargs): - kwargs.setdefault('height', _INDICATOR_HEIGHT) - fig = new_bokeh_figure(x_range=fig_ohlc.x_range, - active_scroll='xwheel_zoom', - active_drag='xpan', - **kwargs) + kwargs.setdefault("height", _INDICATOR_HEIGHT) + fig = new_bokeh_figure( + x_range=fig_ohlc.x_range, active_scroll="xwheel_zoom", active_drag="xpan", **kwargs + ) fig.xaxis.visible = False fig.yaxis.minor_tick_line_color = None fig.yaxis.ticker.desired_num_ticks = 3 @@ -311,41 +392,52 @@ def set_tooltips(fig, tooltips=(), vline=True, renderers=()): renderers = list(renderers) if is_datetime_index: - formatters = {'@datetime': 'datetime'} + formatters = {"@datetime": "datetime"} tooltips = [("Date", "@datetime{%c}")] + tooltips else: formatters = {} tooltips = [("#", "@index")] + tooltips - fig.add_tools(HoverTool( - point_policy='follow_mouse', - 
renderers=renderers, formatters=formatters, - tooltips=tooltips, mode='vline' if vline else 'mouse')) + fig.add_tools( + HoverTool( + point_policy="follow_mouse", + renderers=renderers, + formatters=formatters, + tooltips=tooltips, + mode="vline" if vline else "mouse", + ) + ) def _plot_equity_section(is_return=False): """Equity section""" # Max DD Dur. line - equity = equity_data['Equity'].copy() - dd_end = equity_data['DrawdownDuration'].idxmax() + equity = equity_data["Equity"].copy() + dd_end = equity_data["DrawdownDuration"].idxmax() if np.isnan(dd_end): dd_start = dd_end = equity.index[0] else: dd_start = equity[:dd_end].idxmax() # If DD not extending into the future, get exact point of intersection with equity if dd_end != equity.index[-1]: - dd_end = np.interp(equity[dd_start], - (equity[dd_end - 1], equity[dd_end]), - (dd_end - 1, dd_end)) + dd_end = np.interp( + equity[dd_start], (equity[dd_end - 1], equity[dd_end]), (dd_end - 1, dd_end) + ) if smooth_equity: - interest_points = pd.Index([ - # Beginning and end - equity.index[0], equity.index[-1], - # Peak equity and peak DD - equity.idxmax(), equity_data['DrawdownPct'].idxmax(), - # Include max dd end points. Otherwise the MaxDD line looks amiss. - dd_start, int(dd_end), min(int(dd_end + 1), equity.size - 1), - ]) - select = pd.Index(trades['ExitBar']).union(interest_points) + interest_points = pd.Index( + [ + # Beginning and end + equity.index[0], + equity.index[-1], + # Peak equity and peak DD + equity.idxmax(), + equity_data["DrawdownPct"].idxmax(), + # Include max dd end points. Otherwise the MaxDD line looks amiss. 
+ dd_start, + int(dd_end), + min(int(dd_end + 1), equity.size - 1), + ] + ) + select = pd.Index(trades["ExitBar"]).union(interest_points) select = select.unique().dropna() equity = equity.iloc[select].reindex(equity.index) equity.interpolate(inplace=True) @@ -357,172 +449,361 @@ def _plot_equity_section(is_return=False): if is_return: equity -= equity.iloc[0] - yaxis_label = 'Return' if is_return else 'Equity' - source_key = 'eq_return' if is_return else 'equity' + yaxis_label = "Return" if is_return else "Equity" + source_key = "eq_return" if is_return else "equity" source.add(equity, source_key) fig = new_indicator_figure( - y_axis_label=yaxis_label, - **(dict(height=80) if plot_drawdown else dict(height=100))) + y_axis_label=yaxis_label, **(dict(height=80) if plot_drawdown else dict(height=100)) + ) # High-watermark drawdown dents - fig.patch('index', 'equity_dd', - source=ColumnDataSource(dict( - index=np.r_[index, index[::-1]], - equity_dd=np.r_[equity, equity.cummax()[::-1]] - )), - fill_color='#ffffea', line_color='#ffcb66') + fig.patch( + "index", + "equity_dd", + source=ColumnDataSource( + dict( + index=np.r_[index, index[::-1]], equity_dd=np.r_[equity, equity.cummax()[::-1]] + ) + ), + fill_color="#ffffea", + line_color="#ffcb66", + ) # Equity line - r = fig.line('index', source_key, source=source, line_width=1.5, line_alpha=1) + r = fig.line("index", source_key, source=source, line_width=1.5, line_alpha=1) if relative_equity: - tooltip_format = f'@{source_key}{{+0,0.[000]%}}' - tick_format = '0,0.[00]%' - legend_format = '{:,.0f}%' + tooltip_format = f"@{source_key}{{+0,0.[000]%}}" + tick_format = "0,0.[00]%" + legend_format = "{:,.0f}%" else: - tooltip_format = f'@{source_key}{{$ 0,0}}' - tick_format = '$ 0.0 a' - legend_format = '${:,.0f}' + tooltip_format = f"@{source_key}{{$ 0,0}}" + tick_format = "$ 0.0 a" + legend_format = "${:,.0f}" set_tooltips(fig, [(yaxis_label, tooltip_format)], renderers=[r]) fig.yaxis.formatter = 
NumeralTickFormatter(format=tick_format) # Peaks argmax = equity.idxmax() - fig.scatter(argmax, equity[argmax], - legend_label='Peak ({})'.format( - legend_format.format(equity[argmax] * (100 if relative_equity else 1))), - color='cyan', size=8) - fig.scatter(index[-1], equity.values[-1], - legend_label='Final ({})'.format( - legend_format.format(equity.iloc[-1] * (100 if relative_equity else 1))), - color='blue', size=8) + fig.scatter( + argmax, + equity[argmax], + legend_label="Peak ({})".format( + legend_format.format(equity[argmax] * (100 if relative_equity else 1)) + ), + color="cyan", + size=8, + ) + fig.scatter( + index[-1], + equity.values[-1], + legend_label="Final ({})".format( + legend_format.format(equity.iloc[-1] * (100 if relative_equity else 1)) + ), + color="blue", + size=8, + ) if not plot_drawdown: - drawdown = equity_data['DrawdownPct'] + drawdown = equity_data["DrawdownPct"] argmax = drawdown.idxmax() - fig.scatter(argmax, equity[argmax], - legend_label='Max Drawdown (-{:.1f}%)'.format(100 * drawdown[argmax]), - color='red', size=8) - dd_timedelta_label = df['datetime'].iloc[int(round(dd_end))] - df['datetime'].iloc[dd_start] - fig.line([dd_start, dd_end], equity.iloc[dd_start], - line_color='red', line_width=2, - legend_label=f'Max Dd Dur. ({dd_timedelta_label})' - .replace(' 00:00:00', '') - .replace('(0 days ', '(')) + fig.scatter( + argmax, + equity[argmax], + legend_label="Max Drawdown (-{:.1f}%)".format(100 * drawdown[argmax]), + color="red", + size=8, + ) + dd_timedelta_label = df["datetime"].iloc[int(round(dd_end))] - df["datetime"].iloc[dd_start] + fig.line( + [dd_start, dd_end], + equity.iloc[dd_start], + line_color="red", + line_width=2, + legend_label=f"Max Dd Dur. 
({dd_timedelta_label})".replace(" 00:00:00", "").replace( + "(0 days ", "(" + ), + ) figs_above_ohlc.append(fig) def _plot_drawdown_section(): """Drawdown section""" fig = new_indicator_figure(y_axis_label="Drawdown", height=80) - drawdown = equity_data['DrawdownPct'] + drawdown = equity_data["DrawdownPct"] argmax = drawdown.idxmax() - source.add(drawdown, 'drawdown') - r = fig.line('index', 'drawdown', source=source, line_width=1.3) - fig.scatter(argmax, drawdown[argmax], - legend_label='Peak (-{:.1f}%)'.format(100 * drawdown[argmax]), - color='red', size=8) - set_tooltips(fig, [('Drawdown', '@drawdown{-0.[0]%}')], renderers=[r]) + source.add(drawdown, "drawdown") + r = fig.line("index", "drawdown", source=source, line_width=1.3) + fig.scatter( + argmax, + drawdown[argmax], + legend_label="Peak (-{:.1f}%)".format(100 * drawdown[argmax]), + color="red", + size=8, + ) + set_tooltips(fig, [("Drawdown", "@drawdown{-0.[0]%}")], renderers=[r]) fig.yaxis.formatter = NumeralTickFormatter(format="-0.[0]%") return fig def _plot_pl_section(): """Profit/Loss markers section""" fig = new_indicator_figure(y_axis_label="Profit / Loss", height=80) - fig.add_layout(Span(location=0, dimension='width', line_color='#666666', - line_dash='dashed', level='underlay', line_width=1)) - trade_source.add(trades['ReturnPct'], 'returns') - size = trades['Size'].abs() + fig.add_layout( + Span( + location=0, + dimension="width", + line_color="#666666", + line_dash="dashed", + level="underlay", + line_width=1, + ) + ) + trade_source.add(trades["ReturnPct"], "returns") + size = trades["Size"].abs() size = np.interp(size, (size.min(), size.max()), (8, 20)) - trade_source.add(size, 'marker_size') - if 'count' in trades: - trade_source.add(trades['count'], 'count') - trade_source.add(trades[['EntryBar', 'ExitBar']].values.tolist(), 'lines') - fig.multi_line(xs='lines', - ys=transform('returns', CustomJSTransform(v_func='return [...xs].map(i => [0, i]);')), - source=trade_source, color='#999', 
line_width=1) - trade_source.add(np.take(['inverted_triangle', 'triangle'], trades['Size'] > 0), 'triangles') + trade_source.add(size, "marker_size") + if "count" in trades: + trade_source.add(trades["count"], "count") + trade_source.add(trades[["EntryBar", "ExitBar"]].values.tolist(), "lines") + fig.multi_line( + xs="lines", + ys=transform("returns", CustomJSTransform(v_func="return [...xs].map(i => [0, i]);")), + source=trade_source, + color="#999", + line_width=1, + ) + trade_source.add( + np.take(["inverted_triangle", "triangle"], trades["Size"] > 0), "triangles" + ) r1 = fig.scatter( - 'index', 'returns', source=trade_source, fill_color=cmap, - marker='triangles', line_color='black', size='marker_size') + "index", + "returns", + source=trade_source, + fill_color=cmap, + marker="triangles", + line_color="black", + size="marker_size", + ) tooltips = [("Size", "@size{0,0}")] - if 'count' in trades: + if "count" in trades: tooltips.append(("Count", "@count{0,0}")) - set_tooltips(fig, tooltips + [("P/L", "@returns{+0.[000]%}")], - vline=False, renderers=[r1]) + set_tooltips(fig, tooltips + [("P/L", "@returns{+0.[000]%}")], vline=False, renderers=[r1]) fig.yaxis.formatter = NumeralTickFormatter(format="0.[00]%") return fig - def _plot_volume_section(): + def _plot_volume_section(symbol=None): """Volume section""" + base_symbol = symbol or primary_symbol + base_fig_ohlc = asset_fig_by_symbol[base_symbol] + base_source = asset_source_by_symbol[base_symbol] fig = new_indicator_figure(height=70, y_axis_label="Volume") + fig.x_range = base_fig_ohlc.x_range fig.yaxis.ticker.desired_num_ticks = 3 - fig.xaxis.formatter = fig_ohlc.xaxis[0].formatter + fig.xaxis.formatter = base_fig_ohlc.xaxis[0].formatter fig.xaxis.visible = True - fig_ohlc.xaxis.visible = False # Show only Volume's xaxis - r = fig.vbar('index', BAR_WIDTH, 'Volume', source=source, color=inc_cmap) - set_tooltips(fig, [('Volume', '@Volume{0.00 a}')], renderers=[r]) + base_fig_ohlc.xaxis.visible = False # Show 
only Volume's xaxis + r = fig.vbar("index", BAR_WIDTH, "Volume", source=base_source, color=inc_cmap) + set_tooltips(fig, [("Volume", "@Volume{0.00 a}")], renderers=[r]) fig.yaxis.formatter = NumeralTickFormatter(format="0 a") + if symbol and symbol in asset_extra_figs_by_symbol: + asset_extra_figs_by_symbol[symbol].append(fig) return fig def _plot_superimposed_ohlc(): """Superimposed, downsampled vbars""" - time_resolution = pd.DatetimeIndex(df['datetime']).resolution - resample_rule = (superimpose if isinstance(superimpose, str) else - dict(day='ME', - hour='D', - minute='h', - second='min', - millisecond='s').get(time_resolution)) + time_resolution = pd.DatetimeIndex(df["datetime"]).resolution + resample_rule = ( + superimpose + if isinstance(superimpose, str) + else dict(day="ME", hour="D", minute="h", second="min", millisecond="s").get( + time_resolution + ) + ) if not resample_rule: warnings.warn( f"'Can't superimpose OHLC data with rule '{resample_rule}'" f"(index datetime resolution: '{time_resolution}'). Skipping.", - stacklevel=4) + stacklevel=4, + ) return - df2 = (df.assign(_width=1).set_index('datetime') - .resample(resample_rule, label='left') - .agg(dict(OHLCV_AGG, _width='count'))) + df2 = ( + df.assign(_width=1) + .set_index("datetime") + .resample(resample_rule, label="left") + .agg(dict(OHLCV_AGG, _width="count")) + ) # Check if resampling was downsampling; error on upsampling - orig_freq = _data_period(df['datetime']) + orig_freq = _data_period(df["datetime"]) resample_freq = _data_period(df2.index) if resample_freq < orig_freq: - raise ValueError('Invalid value for `superimpose`: Upsampling not supported.') + raise ValueError("Invalid value for `superimpose`: Upsampling not supported.") if resample_freq == orig_freq: - warnings.warn('Superimposed OHLC plot matches the original plot. Skipping.', - stacklevel=4) + warnings.warn( + "Superimposed OHLC plot matches the original plot. 
Skipping.", stacklevel=4 + ) return - df2.index = df2['_width'].cumsum().shift(1).fillna(0) - df2.index += df2['_width'] / 2 - .5 - df2['_width'] -= .1 # Candles don't touch + df2.index = df2["_width"].cumsum().shift(1).fillna(0) + df2.index += df2["_width"] / 2 - 0.5 + df2["_width"] -= 0.1 # Candles don't touch - df2['inc'] = (df2.Close >= df2.Open).astype(int).astype(str) + df2["inc"] = (df2.Close >= df2.Open).astype(int).astype(str) df2.index.name = None source2 = ColumnDataSource(df2) - fig_ohlc.segment('index', 'High', 'index', 'Low', source=source2, color='#bbbbbb') - colors_lighter = [lightness(BEAR_COLOR, .92), - lightness(BULL_COLOR, .92)] - fig_ohlc.vbar('index', '_width', 'Open', 'Close', source=source2, line_color=None, - fill_color=factor_cmap('inc', colors_lighter, ['0', '1'])) + fig_ohlc.segment("index", "High", "index", "Low", source=source2, color="#bbbbbb") + colors_lighter = [lightness(BEAR_COLOR, 0.92), lightness(BULL_COLOR, 0.92)] + fig_ohlc.vbar( + "index", + "_width", + "Open", + "Close", + source=source2, + line_color=None, + fill_color=factor_cmap("inc", colors_lighter, ["0", "1"]), + ) def _plot_ohlc(): """Main OHLC bars""" - fig_ohlc.segment('index', 'High', 'index', 'Low', source=source, color="black", - legend_label='OHLC') - r = fig_ohlc.vbar('index', BAR_WIDTH, 'Open', 'Close', source=source, - line_color="black", fill_color=inc_cmap, legend_label='OHLC') + fig_ohlc.segment( + "index", "High", "index", "Low", source=source, color="black", legend_label="OHLC" + ) + r = fig_ohlc.vbar( + "index", + BAR_WIDTH, + "Open", + "Close", + source=source, + line_color="black", + fill_color=inc_cmap, + legend_label="OHLC", + ) return r - def _plot_ohlc_trades(): + def _plot_ohlc_trades(fig, symbol=None): """Trade entry / exit markers on OHLC plot""" - trade_source.add(trades[['EntryBar', 'ExitBar']].values.tolist(), 'position_lines_xs') - trade_source.add(trades[['EntryPrice', 'ExitPrice']].values.tolist(), 'position_lines_ys') - 
fig_ohlc.multi_line(xs='position_lines_xs', ys='position_lines_ys', - source=trade_source, line_color=trades_cmap, - legend_label=f'Trades ({len(trades)})', - line_width=8, line_alpha=1, line_dash='dotted') + if symbol is None: + local_trades = trades + elif "Symbol" in trades.columns: + local_trades = trades.loc[trades["Symbol"] == symbol] + else: + local_trades = trades.iloc[:0] + if local_trades.empty: + return + local_trade_source = ColumnDataSource( + dict( + index=local_trades["ExitBar"], + datetime=local_trades["ExitTime"], + size=local_trades["Size"], + returns_positive=(local_trades["ReturnPct"] > 0).astype(int).astype(str), + position_lines_xs=local_trades[["EntryBar", "ExitBar"]].values.tolist(), + position_lines_ys=local_trades[["EntryPrice", "ExitPrice"]].values.tolist(), + ) + ) + fig.multi_line( + xs="position_lines_xs", + ys="position_lines_ys", + source=local_trade_source, + line_color=trades_cmap, + legend_label=f"Trades ({len(local_trades)})", + line_width=8, + line_alpha=1, + line_dash="dotted", + ) + + def _plot_asset_ohlc_sections(): + if not is_multi_asset: + return [] + + def _align_asset_to_primary_index(asset_df: pd.DataFrame, target_index: pd.DatetimeIndex): + if asset_df.index.equals(target_index): + return asset_df + + # Map each source bar to the first target timestamp at or after it + bucket = target_index.searchsorted(asset_df.index, side="left") + valid = bucket < len(target_index) + if not np.any(valid): + return asset_df.iloc[:0] + + grouped = asset_df.iloc[np.flatnonzero(valid)].groupby(bucket[valid]).agg(OHLCV_AGG) + out = pd.DataFrame(index=target_index, columns=asset_df.columns, dtype=float) + out.iloc[grouped.index.to_numpy(dtype=int)] = grouped.to_numpy() + return out + + def new_asset_figure(primary_x_range): + fig = new_bokeh_figure( + x_range=primary_x_range, + active_scroll="xwheel_zoom", + active_drag="xpan", + y_axis_label="Price", + height=220, + ) + fig.xaxis.visible = False + fig.yaxis.minor_tick_line_color = None + 
fig.yaxis.ticker.desired_num_ticks = 3 + return fig + + asset_figs = [] + for symbol in asset_symbols[1:]: + asset_df = assets[symbol][list(OHLCV_AGG.keys())].copy(deep=False) + if is_datetime_index: + target_index = pd.DatetimeIndex(df["datetime"]) + asset_df = _align_asset_to_primary_index(asset_df, target_index) + asset_df.index.name = None + asset_df["datetime"] = asset_df.index + asset_df = asset_df.reset_index(drop=True) + + asset_source = ColumnDataSource(asset_df) + asset_source.add( + (asset_df.Close >= asset_df.Open).values.astype(np.uint8).astype(str), "inc" + ) + + fig = new_asset_figure(fig_ohlc.x_range) + asset_fig_by_symbol[symbol] = fig + asset_source_by_symbol[symbol] = asset_source + asset_extra_figs_by_symbol[symbol] = [] + asset_ohlc_extreme_by_symbol[symbol] = asset_df[["High", "Low"]].copy(deep=False) + asset_ohlc_tooltips_by_symbol[symbol] = [ + ("x, y", NBSP.join(("$index", "$y{0,0.0[0000]}"))), + ( + "OHLC", + NBSP.join( + ( + "@Open{0,0.0[0000]}", + "@High{0,0.0[0000]}", + "@Low{0,0.0[0000]}", + "@Close{0,0.0[0000]}", + ) + ), + ), + ("Volume", "@Volume{0,0}"), + ] + fig.segment( + "index", + "High", + "index", + "Low", + source=asset_source, + color="black", + legend_label=symbol, + ) + asset_bars = fig.vbar( + "index", + BAR_WIDTH, + "Open", + "Close", + source=asset_source, + line_color="black", + fill_color=factor_cmap("inc", COLORS, ["0", "1"]), + legend_label=symbol, + ) + asset_ohlc_bars_by_symbol[symbol] = asset_bars + if is_datetime_index: + fig.xaxis.formatter = fig_ohlc.xaxis[0].formatter + if plot_trades: + _plot_ohlc_trades(fig, symbol=symbol) + asset_figs.append(fig) + return asset_figs def _plot_indicators(): """Strategy indicators""" @@ -530,8 +811,7 @@ def _plot_indicators(): def _too_many_dims(value): assert value.ndim >= 2 if value.ndim > 2: - warnings.warn(f"Can't plot indicators with >2D ('{value.name}')", - stacklevel=5) + warnings.warn(f"Can't plot indicators with >2D ('{value.name}')", stacklevel=5) return True 
return False @@ -552,25 +832,42 @@ def __eq__(self, other): if _too_many_dims(value): continue + indicator_symbol = value._opts.get("symbol") + overlay_symbol = ( + indicator_symbol if indicator_symbol in asset_fig_by_symbol else primary_symbol + ) + # Use .get()! A user might have assigned a Strategy.data-evolved # _Array without Strategy.I() - is_overlay = value._opts.get('overlay') - is_scatter = value._opts.get('scatter') - is_muted = not value._opts.get('plot') + is_overlay = value._opts.get("overlay") + is_scatter = value._opts.get("scatter") + is_muted = not value._opts.get("plot") # is overlay => show muted, hide legend item. non-overlay => don't show at all if is_muted and not is_overlay: continue if is_overlay: - fig = fig_ohlc + fig = asset_fig_by_symbol[overlay_symbol] + indicator_source = asset_source_by_symbol[overlay_symbol] + indicator_ohlc_extreme = asset_ohlc_extreme_by_symbol[overlay_symbol] + indicator_ohlc_tooltips = asset_ohlc_tooltips_by_symbol[overlay_symbol] else: fig = new_indicator_figure() - indicator_figs.append(fig) + if overlay_symbol in asset_extra_figs_by_symbol: + asset_extra_figs_by_symbol[overlay_symbol].append(fig) + else: + indicator_figs.append(fig) + indicator_source = source + indicator_ohlc_extreme = ohlc_extreme_values + indicator_ohlc_tooltips = ohlc_tooltips tooltips = [] - colors = value._opts['color'] - colors = colors and cycle(_as_list(colors)) or ( - cycle([next(ohlc_colors)]) if is_overlay else colorgen()) + colors = value._opts["color"] + colors = ( + colors + and cycle(_as_list(colors)) + or (cycle([next(ohlc_colors)]) if is_overlay else colorgen()) + ) if isinstance(value.name, str): tooltip_label = value.name @@ -581,46 +878,76 @@ def __eq__(self, other): for j, arr in enumerate(value): color = next(colors) - source_name = f'{legend_labels[j]}_{i}_{j}' + source_name = f"{legend_labels[j]}_{i}_{j}" if arr.dtype == bool: arr = arr.astype(int) - source.add(arr, source_name) - 
tooltips.append(f'@{{{source_name}}}{{0,0.0[0000]}}') + indicator_source.add(arr, source_name) + tooltips.append(f"@{{{source_name}}}{{0,0.0[0000]}}") kwargs = {} if not is_muted: - kwargs['legend_label'] = legend_labels[j] + kwargs["legend_label"] = legend_labels[j] if is_overlay: - ohlc_extreme_values[source_name] = arr + indicator_ohlc_extreme[source_name] = arr if is_scatter: r2 = fig.circle( - 'index', source_name, source=source, - color=color, line_color='black', fill_alpha=.8, - radius=BAR_WIDTH / 2 * .9, **kwargs) + "index", + source_name, + source=indicator_source, + color=color, + line_color="black", + fill_alpha=0.8, + radius=BAR_WIDTH / 2 * 0.9, + **kwargs, + ) else: r2 = fig.line( - 'index', source_name, source=source, - line_color=color, line_width=1.4 if is_muted else 1.5, **kwargs) + "index", + source_name, + source=indicator_source, + line_color=color, + line_width=1.4 if is_muted else 1.5, + **kwargs, + ) # r != r2 r2.muted = is_muted else: if is_scatter: r = fig.circle( - 'index', source_name, source=source, - color=color, radius=BAR_WIDTH / 2 * .6, **kwargs) + "index", + source_name, + source=indicator_source, + color=color, + radius=BAR_WIDTH / 2 * 0.6, + **kwargs, + ) else: r = fig.line( - 'index', source_name, source=source, - line_color=color, line_width=1.3, **kwargs) + "index", + source_name, + source=indicator_source, + line_color=color, + line_width=1.3, + **kwargs, + ) # Add dashed centerline just because mean = try_(lambda: float(pd.Series(arr).mean()), default=np.nan) - if not np.isnan(mean) and (abs(mean) < .1 or - round(abs(mean), 1) == .5 or - round(abs(mean), -1) in (50, 100, 200)): - fig.add_layout(Span(location=float(mean), dimension='width', - line_color='#666666', line_dash='dashed', - level='underlay', line_width=.5)) + if not np.isnan(mean) and ( + abs(mean) < 0.1 + or round(abs(mean), 1) == 0.5 + or round(abs(mean), -1) in (50, 100, 200) + ): + fig.add_layout( + Span( + location=float(mean), + dimension="width", + 
line_color="#666666", + line_dash="dashed", + level="underlay", + line_width=0.5, + ) + ) if is_overlay: - ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips))) + indicator_ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips))) else: set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r]) # If the sole indicator line on this figure, @@ -643,16 +970,23 @@ def __eq__(self, other): if plot_pl: figs_above_ohlc.append(_plot_pl_section()) + asset_ohlc_figs = _plot_asset_ohlc_sections() + + for symbol in asset_symbols[1:] if is_multi_asset else (): + if plot_volume: + _plot_volume_section(symbol=symbol) + if plot_volume: - fig_volume = _plot_volume_section() + fig_volume = _plot_volume_section(symbol=primary_symbol) figs_below_ohlc.append(fig_volume) if superimpose and is_datetime_index: _plot_superimposed_ohlc() ohlc_bars = _plot_ohlc() + asset_ohlc_bars_by_symbol[primary_symbol] = ohlc_bars if plot_trades: - _plot_ohlc_trades() + _plot_ohlc_trades(fig_ohlc, symbol=asset_symbols[0] if is_multi_asset else None) indicator_figs = _plot_indicators() if reverse_indicators: indicator_figs = indicator_figs[::-1] @@ -660,43 +994,56 @@ def __eq__(self, other): _watermark(fig_ohlc) - set_tooltips(fig_ohlc, ohlc_tooltips, vline=True, renderers=[ohlc_bars]) + for symbol, asset_fig in asset_fig_by_symbol.items(): + set_tooltips( + asset_fig, + asset_ohlc_tooltips_by_symbol[symbol], + vline=True, + renderers=[asset_ohlc_bars_by_symbol[symbol]], + ) - source.add(ohlc_extreme_values.min(1), 'ohlc_low') - source.add(ohlc_extreme_values.max(1), 'ohlc_high') + source.add(asset_ohlc_extreme_by_symbol[primary_symbol].min(1), "ohlc_low") + source.add(asset_ohlc_extreme_by_symbol[primary_symbol].max(1), "ohlc_high") - custom_js_args = dict(ohlc_range=fig_ohlc.y_range, - source=source) + custom_js_args = dict(ohlc_range=fig_ohlc.y_range, source=source) if plot_volume: custom_js_args.update(volume_range=fig_volume.y_range) - 
fig_ohlc.x_range.js_on_change('end', CustomJS(args=custom_js_args, - code=_AUTOSCALE_JS_CALLBACK)) + fig_ohlc.x_range.js_on_change("end", CustomJS(args=custom_js_args, code=_AUTOSCALE_JS_CALLBACK)) - figs = figs_above_ohlc + [fig_ohlc] + figs_below_ohlc + per_asset_blocks = [] + if is_multi_asset: + for symbol in asset_symbols[1:]: + per_asset_blocks.extend(asset_extra_figs_by_symbol.get(symbol, [])) + per_asset_blocks.append(asset_fig_by_symbol[symbol]) + + figs = figs_above_ohlc + per_asset_blocks + [fig_ohlc] + figs_below_ohlc linked_crosshair = CrosshairTool( - dimensions='both', line_color='lightgrey', - overlay=(Span(dimension="width", line_dash="dotted", line_width=1), - Span(dimension="height", line_dash="dotted", line_width=1)), + dimensions="both", + line_color="lightgrey", + overlay=( + Span(dimension="width", line_dash="dotted", line_width=1), + Span(dimension="height", line_dash="dotted", line_width=1), + ), ) for f in figs: if f.legend: f.legend.visible = show_legend - f.legend.location = 'top_left' + f.legend.location = "top_left" f.legend.border_line_width = 1 - f.legend.border_line_color = '#333333' + f.legend.border_line_color = "#333333" f.legend.padding = 5 f.legend.spacing = 0 f.legend.margin = 0 - f.legend.label_text_font_size = '8pt' + f.legend.label_text_font_size = "8pt" f.legend.click_policy = "hide" - f.legend.background_fill_alpha = .9 + f.legend.background_fill_alpha = 0.9 f.min_border_left = 0 f.min_border_top = 3 f.min_border_bottom = 6 f.min_border_right = 10 - f.outline_line_color = '#666666' + f.outline_line_color = "#666666" f.add_tools(linked_crosshair) wheelzoom_tool = next(wz for wz in f.tools if isinstance(wz, WheelZoomTool)) @@ -704,82 +1051,92 @@ def __eq__(self, other): kwargs = {} if plot_width is None: - kwargs['sizing_mode'] = 'stretch_width' + kwargs["sizing_mode"] = "stretch_width" fig = gridplot( figs, ncols=1, - toolbar_location='right', + toolbar_location="right", toolbar_options=dict(logo=None), merge_tools=True, - 
**kwargs # type: ignore + **kwargs, # type: ignore ) - show(fig, browser=None if open_browser else 'none') + show(fig, browser=None if open_browser else "none") return fig -def plot_heatmaps(heatmap: pd.Series, agg: Union[Callable, str], ncols: int, - filename: str = '', plot_width: int = 1200, open_browser: bool = True): - if not (isinstance(heatmap, pd.Series) and - isinstance(heatmap.index, pd.MultiIndex)): - raise ValueError('heatmap must be heatmap Series as returned by ' - '`Backtest.optimize(..., return_heatmap=True)`') +def plot_heatmaps( + heatmap: pd.Series, + agg: Union[Callable, str], + ncols: int, + filename: str = "", + plot_width: int = 1200, + open_browser: bool = True, +): + if not (isinstance(heatmap, pd.Series) and isinstance(heatmap.index, pd.MultiIndex)): + raise ValueError( + "heatmap must be heatmap Series as returned by " + "`Backtest.optimize(..., return_heatmap=True)`" + ) if len(heatmap.index.levels) < 2: - raise ValueError('`plot_heatmap()` requires at least two optimization ' - 'variables to plot') + raise ValueError("`plot_heatmap()` requires at least two optimization variables to plot") _bokeh_reset(filename) param_combinations = combinations(heatmap.index.names, 2) - dfs = [heatmap.groupby(list(dims)).agg(agg).to_frame(name='_Value') - for dims in param_combinations] + dfs = [ + heatmap.groupby(list(dims)).agg(agg).to_frame(name="_Value") for dims in param_combinations + ] figs: list[_figure] = [] - cmap = LinearColorMapper(palette='Viridis256', - low=min(df.min().min() for df in dfs), - high=max(df.max().max() for df in dfs), - nan_color='white') + cmap = LinearColorMapper( + palette="Viridis256", + low=min(df.min().min() for df in dfs), + high=max(df.max().max() for df in dfs), + nan_color="white", + ) for df in dfs: name1, name2 = df.index.names level1 = df.index.levels[0].astype(str).tolist() level2 = df.index.levels[1].astype(str).tolist() df = df.reset_index() - df[name1] = df[name1].astype('str') - df[name2] = 
df[name2].astype('str') - - fig = _figure(x_range=level1, # type: ignore[call-arg] - y_range=level2, - x_axis_label=name1, - y_axis_label=name2, - width=plot_width // ncols, - height=plot_width // ncols, - tools='box_zoom,reset,save', - tooltips=[(name1, '@' + name1), - (name2, '@' + name2), - ('Value', '@_Value{0.[000]}')]) - fig.grid.grid_line_color = None # type: ignore[attr-defined] - fig.axis.axis_line_color = None # type: ignore[attr-defined] + df[name1] = df[name1].astype("str") + df[name2] = df[name2].astype("str") + + fig = _figure( + x_range=level1, # type: ignore[call-arg] + y_range=level2, + x_axis_label=name1, + y_axis_label=name2, + width=plot_width // ncols, + height=plot_width // ncols, + tools="box_zoom,reset,save", + tooltips=[(name1, "@" + name1), (name2, "@" + name2), ("Value", "@_Value{0.[000]}")], + ) + fig.grid.grid_line_color = None # type: ignore[attr-defined] + fig.axis.axis_line_color = None # type: ignore[attr-defined] fig.axis.major_tick_line_color = None # type: ignore[attr-defined] - fig.axis.major_label_standoff = 0 # type: ignore[attr-defined] + fig.axis.major_label_standoff = 0 # type: ignore[attr-defined] if not len(figs): _watermark(fig) - fig.rect(x=name1, - y=name2, - width=1, - height=1, - source=df, - line_color=None, - fill_color=dict(field='_Value', - transform=cmap)) + fig.rect( + x=name1, + y=name2, + width=1, + height=1, + source=df, + line_color=None, + fill_color=dict(field="_Value", transform=cmap), + ) figs.append(fig) fig = gridplot( figs, # type: ignore ncols=ncols, toolbar_options=dict(logo=None), - toolbar_location='above', + toolbar_location="above", merge_tools=True, ) - show(fig, browser=None if open_browser else 'none') + show(fig, browser=None if open_browser else "none") return fig diff --git a/backtesting/_stats.py b/backtesting/_stats.py index 1f01c5a3..e8dc0097 100644 --- a/backtesting/_stats.py +++ b/backtesting/_stats.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import 
TYPE_CHECKING, List, Union, cast +from typing import TYPE_CHECKING, Dict, List, Union, cast import numpy as np import pandas as pd @@ -14,17 +14,17 @@ def compute_drawdown_duration_peaks(dd: pd.Series): iloc = np.unique(np.r_[(dd == 0).values.nonzero()[0], len(dd) - 1]) iloc = pd.Series(iloc, index=dd.index[iloc]) - df = iloc.to_frame('iloc').assign(prev=iloc.shift()) - df = df[df['iloc'] > df['prev'] + 1].astype(np.int64) + df = iloc.to_frame("iloc").assign(prev=iloc.shift()) + df = df[df["iloc"] > df["prev"] + 1].astype(np.int64) # If no drawdown since no trade, avoid below for pandas sake and return nan series if not len(df): return (dd.replace(0, np.nan),) * 2 - df['duration'] = df['iloc'].map(dd.index.__getitem__) - df['prev'].map(dd.index.__getitem__) - df['peak_dd'] = df.apply(lambda row: dd.iloc[row['prev']:row['iloc'] + 1].max(), axis=1) + df["duration"] = df["iloc"].map(dd.index.__getitem__) - df["prev"].map(dd.index.__getitem__) + df["peak_dd"] = df.apply(lambda row: dd.iloc[row["prev"] : row["iloc"] + 1].max(), axis=1) df = df.reindex(dd.index) - return df['duration'], df['peak_dd'] + return df["duration"], df["peak_dd"] def geometric_mean(returns: pd.Series) -> float: @@ -35,86 +35,97 @@ def geometric_mean(returns: pd.Series) -> float: def compute_stats( - trades: Union[List['Trade'], pd.DataFrame], - equity: np.ndarray, - ohlc_data: pd.DataFrame, - strategy_instance: Strategy | None, - risk_free_rate: float = 0, + trades: Union[List["Trade"], pd.DataFrame], + equity: np.ndarray, + ohlc_data: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + strategy_instance: Strategy | None, + risk_free_rate: float = 0, ) -> pd.Series: assert -1 < risk_free_rate < 1 - index = ohlc_data.index + if isinstance(ohlc_data, dict): + first_symbol = next(iter(ohlc_data)) + primary_ohlc_data = ohlc_data[first_symbol] + else: + first_symbol = None + primary_ohlc_data = ohlc_data + + index = primary_ohlc_data.index dd = 1 - equity / np.maximum.accumulate(equity) dd_dur, dd_peaks 
= compute_drawdown_duration_peaks(pd.Series(dd, index=index)) - equity_df = pd.DataFrame({ - 'Equity': equity, - 'DrawdownPct': dd, - 'DrawdownDuration': dd_dur}, - index=index) + equity_df = pd.DataFrame( + {"Equity": equity, "DrawdownPct": dd, "DrawdownDuration": dd_dur}, index=index + ) if isinstance(trades, pd.DataFrame): trades_df: pd.DataFrame = trades commissions = None # Not shown else: # Came straight from Backtest.run() - trades_df = pd.DataFrame({ - 'Size': [t.size for t in trades], - 'EntryBar': [t.entry_bar for t in trades], - 'ExitBar': [t.exit_bar for t in trades], - 'EntryPrice': [t.entry_price for t in trades], - 'ExitPrice': [t.exit_price for t in trades], - 'SL': [t.sl for t in trades], - 'TP': [t.tp for t in trades], - 'PnL': [t.pl for t in trades], - 'Commission': [t._commissions for t in trades], - 'ReturnPct': [t.pl_pct for t in trades], - 'EntryTime': [t.entry_time for t in trades], - 'ExitTime': [t.exit_time for t in trades], - }) - trades_df['Duration'] = trades_df['ExitTime'] - trades_df['EntryTime'] - trades_df['Tag'] = [t.tag for t in trades] + trades_df = pd.DataFrame( + { + "Size": [t.size for t in trades], + "EntryBar": [t.entry_bar for t in trades], + "ExitBar": [t.exit_bar for t in trades], + "EntryPrice": [t.entry_price for t in trades], + "ExitPrice": [t.exit_price for t in trades], + "SL": [t.sl for t in trades], + "TP": [t.tp for t in trades], + "PnL": [t.pl for t in trades], + "Commission": [t._commissions for t in trades], + "ReturnPct": [t.pl_pct for t in trades], + "EntryTime": [t.entry_time for t in trades], + "ExitTime": [t.exit_time for t in trades], + } + ) + trades_df["Duration"] = trades_df["ExitTime"] - trades_df["EntryTime"] + trades_df["Tag"] = [t.tag for t in trades] + if isinstance(ohlc_data, dict): + trades_df["Symbol"] = [getattr(t, "symbol", first_symbol) for t in trades] # Add indicator values if len(trades_df) and strategy_instance: for ind in strategy_instance._indicators: ind = np.atleast_2d(ind) for i, 
values in enumerate(ind): # multi-d indicators - suffix = f'_{i}' if len(ind) > 1 else '' - trades_df[f'Entry_{ind.name}{suffix}'] = values[trades_df['EntryBar'].values] - trades_df[f'Exit_{ind.name}{suffix}'] = values[trades_df['ExitBar'].values] + suffix = f"_{i}" if len(ind) > 1 else "" + trades_df[f"Entry_{ind.name}{suffix}"] = values[trades_df["EntryBar"].values] + trades_df[f"Exit_{ind.name}{suffix}"] = values[trades_df["ExitBar"].values] commissions = sum(t._commissions for t in trades) del trades - pl = trades_df['PnL'] - returns = trades_df['ReturnPct'] - durations = trades_df['Duration'] + pl = trades_df["PnL"] + returns = trades_df["ReturnPct"] + durations = trades_df["Duration"] def _round_timedelta(value, _period=_data_period(index)): if not isinstance(value, pd.Timedelta): return value - resolution = getattr(_period, 'resolution_string', None) or _period.resolution + resolution = getattr(_period, "resolution_string", None) or _period.resolution return value.ceil(resolution) s = pd.Series(dtype=object) - s.loc['Start'] = index[0] - s.loc['End'] = index[-1] - s.loc['Duration'] = s.End - s.Start + s.loc["Start"] = index[0] + s.loc["End"] = index[-1] + s.loc["Duration"] = s.End - s.Start have_position = np.repeat(0, len(index)) - for t in trades_df[['EntryBar', 'ExitBar']].itertuples(index=False): - have_position[t.EntryBar:t.ExitBar + 1] = 1 + for t in trades_df[["EntryBar", "ExitBar"]].itertuples(index=False): + have_position[t.EntryBar : t.ExitBar + 1] = 1 - s.loc['Exposure Time [%]'] = have_position.mean() * 100 # In "n bars" time, not index time - s.loc['Equity Final [$]'] = equity[-1] - s.loc['Equity Peak [$]'] = equity.max() + s.loc["Exposure Time [%]"] = have_position.mean() * 100 # In "n bars" time, not index time + s.loc["Equity Final [$]"] = equity[-1] + s.loc["Equity Peak [$]"] = equity.max() if commissions: - s.loc['Commissions [$]'] = commissions - s.loc['Return [%]'] = (equity[-1] - equity[0]) / equity[0] * 100 + s.loc["Commissions [$]"] = 
commissions + s.loc["Return [%]"] = (equity[-1] - equity[0]) / equity[0] * 100 first_trading_bar = _indicator_warmup_nbars(strategy_instance) - c = ohlc_data.Close.values - s.loc['Buy & Hold Return [%]'] = (c[-1] - c[first_trading_bar]) / c[first_trading_bar] * 100 # long-only return + c = primary_ohlc_data.Close.values + s.loc["Buy & Hold Return [%]"] = ( + (c[-1] - c[first_trading_bar]) / c[first_trading_bar] * 100 + ) # long-only return gmean_day_return: float = 0 day_returns = np.array(np.nan) @@ -122,37 +133,58 @@ def _round_timedelta(value, _period=_data_period(index)): is_datetime_index = isinstance(index, pd.DatetimeIndex) if is_datetime_index: freq_days = cast(pd.Timedelta, _data_period(index)).days - have_weekends = index.dayofweek.to_series().between(5, 6).mean() > 2 / 7 * .6 + have_weekends = index.dayofweek.to_series().between(5, 6).mean() > 2 / 7 * 0.6 annual_trading_days = ( - 52 if freq_days == 7 else - 12 if freq_days == 31 else - 1 if freq_days == 365 else - (365 if have_weekends else 252)) - freq = {7: 'W', 31: 'ME', 365: 'YE'}.get(freq_days, 'D') - day_returns = equity_df['Equity'].resample(freq).last().dropna().pct_change() + 52 + if freq_days == 7 + else 12 + if freq_days == 31 + else 1 + if freq_days == 365 + else (365 if have_weekends else 252) + ) + freq = {7: "W", 31: "ME", 365: "YE"}.get(freq_days, "D") + day_returns = equity_df["Equity"].resample(freq).last().dropna().pct_change() gmean_day_return = geometric_mean(day_returns) # Annualized return and risk metrics are computed based on the (mostly correct) # assumption that the returns are compounded. See: https://dx.doi.org/10.2139/ssrn.3054517 # Our annualized return matches `empyrical.annual_return(day_returns)` whereas # our risk doesn't; they use the simpler approach below. - annualized_return = (1 + gmean_day_return)**annual_trading_days - 1 - s.loc['Return (Ann.) [%]'] = annualized_return * 100 - s.loc['Volatility (Ann.) 
[%]'] = np.sqrt((day_returns.var(ddof=int(bool(day_returns.shape))) + (1 + gmean_day_return)**2)**annual_trading_days - (1 + gmean_day_return)**(2 * annual_trading_days)) * 100 # noqa: E501 + annualized_return = (1 + gmean_day_return) ** annual_trading_days - 1 + s.loc["Return (Ann.) [%]"] = annualized_return * 100 + s.loc["Volatility (Ann.) [%]"] = ( + np.sqrt( + (day_returns.var(ddof=int(bool(day_returns.shape))) + (1 + gmean_day_return) ** 2) + ** annual_trading_days + - (1 + gmean_day_return) ** (2 * annual_trading_days) + ) + * 100 + ) # noqa: E501 # s.loc['Return (Ann.) [%]'] = gmean_day_return * annual_trading_days * 100 # s.loc['Risk (Ann.) [%]'] = day_returns.std(ddof=1) * np.sqrt(annual_trading_days) * 100 if is_datetime_index: - time_in_years = (s.loc['Duration'].days + s.loc['Duration'].seconds / 86400) / annual_trading_days - s.loc['CAGR [%]'] = ((s.loc['Equity Final [$]'] / equity[0])**(1 / time_in_years) - 1) * 100 if time_in_years else np.nan # noqa: E501 + time_in_years = ( + s.loc["Duration"].days + s.loc["Duration"].seconds / 86400 + ) / annual_trading_days + s.loc["CAGR [%]"] = ( + ((s.loc["Equity Final [$]"] / equity[0]) ** (1 / time_in_years) - 1) * 100 + if time_in_years + else np.nan + ) # noqa: E501 # Our Sharpe mismatches `empyrical.sharpe_ratio()` because they use arithmetic mean return # and simple standard deviation - s.loc['Sharpe Ratio'] = (s.loc['Return (Ann.) [%]'] - risk_free_rate * 100) / (s.loc['Volatility (Ann.) [%]'] or np.nan) # noqa: E501 + s.loc["Sharpe Ratio"] = (s.loc["Return (Ann.) [%]"] - risk_free_rate * 100) / ( + s.loc["Volatility (Ann.) 
[%]"] or np.nan + ) # noqa: E501 # Our Sortino mismatches `empyrical.sortino_ratio()` because they use arithmetic mean return - with np.errstate(divide='ignore'): - s.loc['Sortino Ratio'] = (annualized_return - risk_free_rate) / (np.sqrt(np.mean(day_returns.clip(-np.inf, 0)**2)) * np.sqrt(annual_trading_days)) # noqa: E501 + with np.errstate(divide="ignore"): + s.loc["Sortino Ratio"] = (annualized_return - risk_free_rate) / ( + np.sqrt(np.mean(day_returns.clip(-np.inf, 0) ** 2)) * np.sqrt(annual_trading_days) + ) # noqa: E501 max_dd = -np.nan_to_num(dd.max()) - s.loc['Calmar Ratio'] = annualized_return / (-max_dd or np.nan) + s.loc["Calmar Ratio"] = annualized_return / (-max_dd or np.nan) equity_log_returns = np.log(equity[1:] / equity[:-1]) market_log_returns = np.log(c[1:] / c[:-1]) beta = np.nan @@ -161,29 +193,35 @@ def _round_timedelta(value, _period=_data_period(index)): cov_matrix = np.cov(equity_log_returns, market_log_returns) beta = cov_matrix[0, 1] / cov_matrix[1, 1] # Jensen CAPM Alpha: can be strongly positive when beta is negative and B&H Return is large - s.loc['Alpha [%]'] = s.loc['Return [%]'] - risk_free_rate * 100 - beta * (s.loc['Buy & Hold Return [%]'] - risk_free_rate * 100) # noqa: E501 - s.loc['Beta'] = beta - s.loc['Max. Drawdown [%]'] = max_dd * 100 - s.loc['Avg. Drawdown [%]'] = -dd_peaks.mean() * 100 - s.loc['Max. Drawdown Duration'] = _round_timedelta(dd_dur.max()) - s.loc['Avg. Drawdown Duration'] = _round_timedelta(dd_dur.mean()) - s.loc['# Trades'] = n_trades = len(trades_df) + s.loc["Alpha [%]"] = ( + s.loc["Return [%]"] + - risk_free_rate * 100 + - beta * (s.loc["Buy & Hold Return [%]"] - risk_free_rate * 100) + ) # noqa: E501 + s.loc["Beta"] = beta + s.loc["Max. Drawdown [%]"] = max_dd * 100 + s.loc["Avg. Drawdown [%]"] = -dd_peaks.mean() * 100 + s.loc["Max. Drawdown Duration"] = _round_timedelta(dd_dur.max()) + s.loc["Avg. 
Drawdown Duration"] = _round_timedelta(dd_dur.mean()) + s.loc["# Trades"] = n_trades = len(trades_df) win_rate = np.nan if not n_trades else (pl > 0).mean() - s.loc['Win Rate [%]'] = win_rate * 100 - s.loc['Best Trade [%]'] = returns.max() * 100 - s.loc['Worst Trade [%]'] = returns.min() * 100 + s.loc["Win Rate [%]"] = win_rate * 100 + s.loc["Best Trade [%]"] = returns.max() * 100 + s.loc["Worst Trade [%]"] = returns.min() * 100 mean_return = geometric_mean(returns) - s.loc['Avg. Trade [%]'] = mean_return * 100 - s.loc['Max. Trade Duration'] = _round_timedelta(durations.max()) - s.loc['Avg. Trade Duration'] = _round_timedelta(durations.mean()) - s.loc['Profit Factor'] = returns[returns > 0].sum() / (abs(returns[returns < 0].sum()) or np.nan) # noqa: E501 - s.loc['Expectancy [%]'] = returns.mean() * 100 - s.loc['SQN'] = np.sqrt(n_trades) * pl.mean() / (pl.std() or np.nan) - s.loc['Kelly Criterion'] = win_rate - (1 - win_rate) / (pl[pl > 0].mean() / -pl[pl < 0].mean()) - - s.loc['_strategy'] = strategy_instance - s.loc['_equity_curve'] = equity_df - s.loc['_trades'] = trades_df + s.loc["Avg. Trade [%]"] = mean_return * 100 + s.loc["Max. Trade Duration"] = _round_timedelta(durations.max()) + s.loc["Avg. 
Trade Duration"] = _round_timedelta(durations.mean()) + s.loc["Profit Factor"] = returns[returns > 0].sum() / ( + abs(returns[returns < 0].sum()) or np.nan + ) # noqa: E501 + s.loc["Expectancy [%]"] = returns.mean() * 100 + s.loc["SQN"] = np.sqrt(n_trades) * pl.mean() / (pl.std() or np.nan) + s.loc["Kelly Criterion"] = win_rate - (1 - win_rate) / (pl[pl > 0].mean() / -pl[pl < 0].mean()) + + s.loc["_strategy"] = strategy_instance + s.loc["_equity_curve"] = equity_df + s.loc["_trades"] = trades_df s = _Stats(s) return s @@ -192,9 +230,12 @@ def _round_timedelta(value, _period=_data_period(index)): class _Stats(pd.Series): def __repr__(self): with pd.option_context( - 'display.max_colwidth', 20, # Prevent expansion due to _equity and _trades dfs - 'display.max_rows', len(self), # Reveal self whole - 'display.precision', 5, # Enough for my eyes at least + "display.max_colwidth", + 20, # Prevent expansion due to _equity and _trades dfs + "display.max_rows", + len(self), # Reveal self whole + "display.precision", + 5, # Enough for my eyes at least # 'format.na_rep', '--', # TODO: Enable once it works ): return super().__repr__() @@ -202,11 +243,27 @@ def __repr__(self): def dummy_stats(): from .backtesting import Trade, _Broker - index = pd.DatetimeIndex(['2025']) - data = pd.DataFrame({col: [np.nan] for col in ('Close',)}, index=index) - trade = Trade(_Broker(data=data, cash=10000, spread=.01, commission=.01, margin=.1, - trade_on_close=True, hedging=True, exclusive_orders=False, index=index), - 1, 1, 0, None) + + index = pd.DatetimeIndex(["2025"]) + data = pd.DataFrame({col: [np.nan] for col in ("Close",)}, index=index) + trade = Trade( + _Broker( + data=data, + cash=10000, + spread=0.01, + commission=0.01, + margin=0.1, + trade_on_close=True, + hedging=True, + exclusive_orders=False, + index=index, + ), + 1, + 1, + 0, + None, + "_", + ) trade._replace(exit_price=1, exit_bar=0) trade._commissions = np.nan return compute_stats([trade], np.r_[[np.nan]], data, None, 0) 
diff --git a/backtesting/_util.py b/backtesting/_util.py index 123abe4e..52493dbb 100644 --- a/backtesting/_util.py +++ b/backtesting/_util.py @@ -10,15 +10,17 @@ from multiprocessing import shared_memory as _mpshm from numbers import Number from threading import Lock -from typing import Dict, List, Optional, Sequence, Union, cast +from typing import Dict, List, Mapping, Optional, Sequence, Union, cast import numpy as np import pandas as pd try: from tqdm.auto import tqdm as _tqdm + _tqdm = partial(_tqdm, leave=False) except ImportError: + def _tqdm(seq, **_): return seq @@ -48,14 +50,14 @@ def _as_str(value) -> str: if isinstance(value, (Number, str)): return str(value) if isinstance(value, pd.DataFrame): - return 'df' - name = str(getattr(value, 'name', '') or '') - if name in ('Open', 'High', 'Low', 'Close', 'Volume'): + return "df" + name = str(getattr(value, "name", "") or "") + if name in ("Open", "High", "Low", "Close", "Volume"): return name[:1] if callable(value): - name = getattr(value, '__name__', value.__class__.__name__).replace('', 'λ') + name = getattr(value, "__name__", value.__class__.__name__).replace("", "λ") if len(name) > 10: - name = name[:9] + '…' + name = name[:9] + "…" return name @@ -69,7 +71,7 @@ def _batch(seq): # XXX: Replace with itertools.batched n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300) for i in range(0, len(seq), n): - yield seq[i:i + n] + yield seq[i : i + n] def _data_period(index) -> Union[pd.Timedelta, Number]: @@ -79,17 +81,24 @@ def _data_period(index) -> Union[pd.Timedelta, Number]: def _strategy_indicators(strategy): - return {attr: indicator - for attr, indicator in strategy.__dict__.items() - if isinstance(indicator, _Indicator)}.items() + return { + attr: indicator + for attr, indicator in strategy.__dict__.items() + if isinstance(indicator, _Indicator) + }.items() def _indicator_warmup_nbars(strategy): if strategy is None: return 0 - nbars = max((np.isnan(indicator.astype(float)).argmin(axis=-1).max() - 
for _, indicator in _strategy_indicators(strategy) - if not indicator._opts['scatter']), default=0) + nbars = max( + ( + np.isnan(indicator.astype(float)).argmin(axis=-1).max() + for _, indicator in _strategy_indicators(strategy) + if not indicator._opts["scatter"] + ), + default=0, + ) return nbars @@ -98,6 +107,7 @@ class _Array(np.ndarray): ndarray extended to supply .name and other arbitrary properties in ._opts dict. """ + def __new__(cls, array, *, name=None, **kwargs): obj = np.asarray(array).view(cls) obj.name = name or array.name @@ -106,8 +116,8 @@ def __new__(cls, array, *, name=None, **kwargs): def __array_finalize__(self, obj): if obj is not None: - self.name = getattr(obj, 'name', '') - self._opts = getattr(obj, '_opts', {}) + self.name = getattr(obj, "name", "") + self._opts = getattr(obj, "_opts", {}) # Make sure properties name and _opts are carried over # when (un-)pickling. @@ -138,13 +148,13 @@ def to_series(self): @property def s(self) -> pd.Series: values = np.atleast_2d(self) - index = self._opts['index'][:values.shape[1]] + index = self._opts["index"][: values.shape[1]] return pd.Series(values[0], index=index, name=self.name) @property def df(self) -> pd.DataFrame: values = np.atleast_2d(np.asarray(self)) - index = self._opts['index'][:values.shape[1]] + index = self._opts["index"][: values.shape[1]] df = pd.DataFrame(values.T, index=index, columns=[self.name] * len(values)) return df @@ -160,8 +170,10 @@ class _Data: and the returned "series" are _not_ `pd.Series` but `np.ndarray` for performance reasons. 
""" - def __init__(self, df: pd.DataFrame): + + def __init__(self, df: pd.DataFrame, *, symbol: Optional[str] = None): self.__df = df + self.__symbol = symbol self.__len = len(df) # Current length self.__pip: Optional[float] = None self.__cache: Dict[str, _Array] = {} @@ -183,62 +195,65 @@ def _set_length(self, length): def _update(self): index = self.__df.index.copy() - self.__arrays = {col: _Array(arr, index=index) - for col, arr in self.__df.items()} + self.__arrays = { + col: _Array(arr, index=index, symbol=self.__symbol) for col, arr in self.__df.items() + } # Leave index as Series because pd.Timestamp nicer API to work with - self.__arrays['__index'] = index + self.__arrays["__index"] = index def __repr__(self): i = min(self.__len, len(self.__df)) - 1 - index = self.__arrays['__index'][i] - items = ', '.join(f'{k}={v}' for k, v in self.__df.iloc[i].items()) - return f'' + index = self.__arrays["__index"][i] + items = ", ".join(f"{k}={v}" for k, v in self.__df.iloc[i].items()) + return f"" def __len__(self): return self.__len @property def df(self) -> pd.DataFrame: - return (self.__df.iloc[:self.__len] - if self.__len < len(self.__df) - else self.__df) + return self.__df.iloc[: self.__len] if self.__len < len(self.__df) else self.__df @property def pip(self) -> float: if self.__pip is None: - self.__pip = float(10**-np.median([len(s.partition('.')[-1]) - for s in self.__arrays['Close'].astype(str)])) + self.__pip = float( + 10 + ** -np.median( + [len(s.partition(".")[-1]) for s in self.__arrays["Close"].astype(str)] + ) + ) return self.__pip def __get_array(self, key) -> _Array: arr = self.__cache.get(key) if arr is None: - arr = self.__cache[key] = cast(_Array, self.__arrays[key][:self.__len]) + arr = self.__cache[key] = cast(_Array, self.__arrays[key][: self.__len]) return arr @property def Open(self) -> _Array: - return self.__get_array('Open') + return self.__get_array("Open") @property def High(self) -> _Array: - return self.__get_array('High') + return 
self.__get_array("High") @property def Low(self) -> _Array: - return self.__get_array('Low') + return self.__get_array("Low") @property def Close(self) -> _Array: - return self.__get_array('Close') + return self.__get_array("Close") @property def Volume(self) -> _Array: - return self.__get_array('Volume') + return self.__get_array("Volume") @property def index(self) -> pd.DatetimeIndex: - return self.__get_array('__index') + return self.__get_array("__index") # Make pickling in Backtest.optimize() work with our catch-all __getattr__ def __getstate__(self): @@ -248,9 +263,133 @@ def __setstate__(self, state): self.__dict__ = state +class _DataColumnView: + def __init__(self, data: "_MultiData", column: str): + self._data = data + self._column = column + + def __getitem__(self, key): + if isinstance(key, str): + return self._data[key][self._column] + return self._data[self._data.primary_symbol][self._column][key] + + def __len__(self): + return len(self._data[self._data.primary_symbol][self._column]) + + def __array__(self, dtype=None): + return np.asarray(self._data[self._data.primary_symbol][self._column], dtype=dtype) + + @property + def shape(self): + return self._data[self._data.primary_symbol][self._column].shape + + @property + def ndim(self): + return self._data[self._data.primary_symbol][self._column].ndim + + @property + def dtype(self): + return self._data[self._data.primary_symbol][self._column].dtype + + @property + def s(self): + return self._data[self._data.primary_symbol][self._column].s + + @property + def df(self): + return self._data[self._data.primary_symbol][self._column].df + + def get(self, symbol: str, default=None): + try: + return self[symbol] + except KeyError: + return default + + def keys(self): + return self._data.symbols + + def items(self): + for symbol in self._data.symbols: + yield symbol, self[symbol] + + def values(self): + for _, values in self.items(): + yield values + + +class _MultiData: + def __init__(self, data: Mapping[str, 
pd.DataFrame]): + if not data: + raise ValueError("Need at least one asset data frame") + self.__assets: Dict[str, _Data] = { + str(symbol): _Data(df, symbol=str(symbol)) for symbol, df in data.items() + } + self.__symbols = tuple(self.__assets) + self.__primary_symbol = self.__symbols[0] + self.__columns: Dict[str, _DataColumnView] = { + col: _DataColumnView(self, col) for col in ("Open", "High", "Low", "Close", "Volume") + } + + @property + def symbols(self) -> tuple[str, ...]: + return self.__symbols + + @property + def primary_symbol(self) -> str: + return self.__primary_symbol + + @property + def assets(self) -> Dict[str, _Data]: + return self.__assets + + def __len__(self): + return len(next(iter(self.__assets.values()))) + + def __getitem__(self, item): + return self.__assets[item] + + def __getattr__(self, item): + if item in self.__columns: + return self.__columns[item] + raise AttributeError(f"Asset '{item}' not in data") + + def get(self, symbol: str, default=None): + return self.__assets.get(symbol, default) + + def keys(self): + return self.__assets.keys() + + def items(self): + return self.__assets.items() + + def values(self): + return self.__assets.values() + + def _set_length(self, length): + for data in self.__assets.values(): + data._set_length(length) + + def _update(self): + for data in self.__assets.values(): + data._update() + + @property + def index(self) -> pd.DatetimeIndex: + return next(iter(self.__assets.values())).index + + @property + def df(self) -> pd.DataFrame: + return self[self.__primary_symbol].df + + @property + def pip(self) -> float: + return self[self.__primary_symbol].pip + + if sys.version_info >= (3, 13): SharedMemory = _mpshm.SharedMemory else: + class SharedMemory(_mpshm.SharedMemory): # From https://github.com/python/cpython/issues/82300#issuecomment-2169035092 __lock = Lock() @@ -260,7 +399,7 @@ def __init__(self, *args, track: bool = True, **kwargs): if track: return super().__init__(*args, **kwargs) with self.__lock: 
- with patch(_mprt, 'register', lambda *a, **kw: None): + with patch(_mprt, "register", lambda *a, **kw: None): super().__init__(*args, **kwargs) def unlink(self): @@ -275,6 +414,7 @@ class SharedMemoryManager: A simple shared memory contextmanager based on https://docs.python.org/3/library/multiprocessing.shared_memory.html#multiprocessing.shared_memory.SharedMemory """ + def __init__(self, create=False) -> None: self._shms: list[SharedMemory] = [] self.__create = create @@ -297,8 +437,11 @@ def __exit__(self, *args, **kwargs): if shm._create: shm.unlink() except Exception: - warnings.warn(f'Failed to unlink shared memory {shm.name!r}', - category=ResourceWarning, stacklevel=2) + warnings.warn( + f"Failed to unlink shared memory {shm.name!r}", + category=ResourceWarning, + stacklevel=2, + ) raise def arr2shm(self, vals): @@ -308,15 +451,17 @@ def arr2shm(self, vals): # np.array can't handle pandas' tz-aware datetimes # https://github.com/numpy/numpy/issues/18279 buf = np.ndarray(vals.shape, dtype=vals.dtype.base, buffer=shm.buf) - has_tz = getattr(vals.dtype, 'tz', None) + has_tz = getattr(vals.dtype, "tz", None) buf[:] = vals.tz_localize(None) if has_tz else vals # Copy into shared memory return shm.name, vals.shape, vals.dtype def df2shm(self, df): - return tuple(( - (column, *self.arr2shm(values)) - for column, values in chain([(self._DF_INDEX_COL, df.index)], df.items()) - )) + return tuple( + ( + (column, *self.arr2shm(values)) + for column, values in chain([(self._DF_INDEX_COL, df.index)], df.items()) + ) + ) @staticmethod def shm2s(shm, shape, dtype) -> pd.Series: @@ -324,14 +469,17 @@ def shm2s(shm, shape, dtype) -> pd.Series: arr.setflags(write=False) return pd.Series(arr, dtype=dtype) - _DF_INDEX_COL = '__bt_index' + _DF_INDEX_COL = "__bt_index" @staticmethod def shm2df(data_shm): shm = [SharedMemory(name=name, create=False, track=False) for _, name, _, _ in data_shm] - df = pd.DataFrame({ - col: SharedMemoryManager.shm2s(shm, shape, dtype) - for shm, 
(col, _, shape, dtype) in zip(shm, data_shm)}) + df = pd.DataFrame( + { + col: SharedMemoryManager.shm2s(shm, shape, dtype) + for shm, (col, _, shape, dtype) in zip(shm, data_shm) + } + ) df.set_index(SharedMemoryManager._DF_INDEX_COL, drop=True, inplace=True) df.index.name = None return df, shm diff --git a/backtesting/backtesting.py b/backtesting/backtesting.py index 9ed77d6b..1922ca4a 100644 --- a/backtesting/backtesting.py +++ b/backtesting/backtesting.py @@ -17,7 +17,7 @@ from itertools import chain, product, repeat from math import copysign from numbers import Number -from typing import Callable, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import numpy as np import pandas as pd @@ -26,15 +26,25 @@ from ._plotting import plot # noqa: I001 from ._stats import compute_stats, dummy_stats from ._util import ( - SharedMemoryManager, _as_str, _Indicator, _Data, _batch, _indicator_warmup_nbars, - _strategy_indicators, patch, try_, _tqdm, + SharedMemoryManager, + _Array, + _as_str, + _Data, + _Indicator, + _MultiData, + _batch, + _indicator_warmup_nbars, + _strategy_indicators, + patch, + try_, + _tqdm, ) __pdoc__ = { - 'Strategy.__init__': False, - 'Order.__init__': False, - 'Position.__init__': False, - 'Trade.__init__': False, + "Strategy.__init__": False, + "Order.__init__": False, + "Position.__init__": False, + "Trade.__init__": False, } @@ -46,38 +56,50 @@ class Strategy(metaclass=ABCMeta): `backtesting.backtesting.Strategy.next` to define your own strategy. 
""" + def __init__(self, broker, data, params): self._indicators = [] self._broker: _Broker = broker - self._data: _Data = data + self._data: Union[_Data, _MultiData] = data self._params = self._check_params(params) def __repr__(self): - return '' + return "" def __str__(self): - params = ','.join(f'{i[0]}={i[1]}' for i in zip(self._params.keys(), - map(_as_str, self._params.values()))) + params = ",".join( + f"{i[0]}={i[1]}" for i in zip(self._params.keys(), map(_as_str, self._params.values())) + ) if params: - params = '(' + params + ')' - return f'{self.__class__.__name__}{params}' + params = "(" + params + ")" + return f"{self.__class__.__name__}{params}" def _check_params(self, params): for k, v in params.items(): if not hasattr(self, k): - suggestions = get_close_matches(k, (attr for attr in dir(self) if not attr.startswith('_'))) + suggestions = get_close_matches( + k, (attr for attr in dir(self) if not attr.startswith("_")) + ) hint = f" Did you mean: {', '.join(suggestions)}?" if suggestions else "" raise AttributeError( f"Strategy '{self.__class__.__name__}' is missing parameter '{k}'. " "Strategy class should define parameters as class variables before they " - "can be optimized or run with." + hint) + "can be optimized or run with." + hint + ) setattr(self, k, v) return params - def I(self, # noqa: E743 - func: Callable, *args, - name=None, plot=True, overlay=None, color=None, scatter=False, - **kwargs) -> np.ndarray: + def I( + self, # noqa: E743 + func: Callable, + *args, + name=None, + plot=True, + overlay=None, + color=None, + scatter=False, + **kwargs, + ) -> np.ndarray: """ Declare an indicator. An indicator is just an array of values (or a tuple of such arrays in case of, e.g., MACD indicator), @@ -124,21 +146,24 @@ def init(): strategy that uses a 200-bar MA). This can affect results. 
""" + def _format_name(name: str) -> str: - return name.format(*map(_as_str, args), - **dict(zip(kwargs.keys(), map(_as_str, kwargs.values())))) + return name.format( + *map(_as_str, args), **dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))) + ) if name is None: - params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values())))) + params = ",".join(filter(None, map(_as_str, chain(args, kwargs.values())))) func_name = _as_str(func) - name = (f'{func_name}({params})' if params else f'{func_name}') + name = f"{func_name}({params})" if params else f"{func_name}" elif isinstance(name, str): name = _format_name(name) elif try_(lambda: all(isinstance(item, str) for item in name), False): name = [_format_name(item) for item in name] else: - raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or ' - '`Sequence[str]`') + raise TypeError( + f"Unexpected `name=` type {type(name)}; expected `str` or `Sequence[str]`" + ) try: value = func(*args, **kwargs) @@ -149,7 +174,7 @@ def _format_name(name: str) -> str: value = value.values.T if value is not None: - value = try_(lambda: np.asarray(value, order='C'), None) + value = try_(lambda: np.asarray(value, order="C"), None) is_arraylike = bool(value is not None and value.shape) # Optionally flip the array if the user returned e.g. `df.values` @@ -158,26 +183,49 @@ def _format_name(name: str) -> str: if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)): raise ValueError( - f'Length of `name=` ({len(name)}) must agree with the number ' - f'of arrays the indicator returns ({value.shape[0]}).') + f"Length of `name=` ({len(name)}) must agree with the number " + f"of arrays the indicator returns ({value.shape[0]})." 
+ ) if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close): raise ValueError( - 'Indicators must return (optionally a tuple of) numpy.arrays of same ' + "Indicators must return (optionally a tuple of) numpy.arrays of same " f'length as `data` (data shape: {self._data.Close.shape}; indicator "{name}" ' - f'shape: {getattr(value, "shape", "")}, returned value: {value})') + f"shape: {getattr(value, 'shape', '')}, returned value: {value})" + ) + + indicator_symbol = next( + ( + arg._opts.get("symbol") + for arg in args + if isinstance(arg, _Array) and arg._opts.get("symbol") + ), + None, + ) if overlay is None and np.issubdtype(value.dtype, np.number): - x = value / self._data.Close + reference_close = ( + self._data[indicator_symbol].Close + if indicator_symbol and isinstance(self._data, _MultiData) + else self._data.Close + ) + x = value / reference_close # By default, overlay if strong majority of indicator values # is within 30% of Close - with np.errstate(invalid='ignore'): - overlay = ((x < 1.4) & (x > .6)).mean() > .6 - - value = _Indicator(value, name=name, plot=plot, overlay=overlay, - color=color, scatter=scatter, - # _Indicator.s Series accessor uses this: - index=self.data.index) + with np.errstate(invalid="ignore"): + overlay = ((x < 1.4) & (x > 0.6)).mean() > 0.6 + + value = _Indicator( + value, + name=name, + plot=plot, + overlay=overlay, + color=color, + scatter=scatter, + symbol=indicator_symbol, + # _Indicator.s Series accessor uses this: + index=self.data.index, + ) self._indicators.append(value) return value @@ -213,16 +261,22 @@ def next(self): """ class __FULL_EQUITY(float): # noqa: N801 - def __repr__(self): return '.9999' # noqa: E704 + def __repr__(self): + return ".9999" # noqa: E704 + _FULL_EQUITY = __FULL_EQUITY(1 - sys.float_info.epsilon) - def buy(self, *, - size: float = _FULL_EQUITY, - limit: Optional[float] = None, - stop: Optional[float] = None, - sl: Optional[float] = None, - tp: Optional[float] = 
None, - tag: object = None) -> 'Order': + def buy( + self, + *, + size: float = _FULL_EQUITY, + limit: Optional[float] = None, + stop: Optional[float] = None, + sl: Optional[float] = None, + tp: Optional[float] = None, + tag: object = None, + symbol: Optional[str] = None, + ) -> "Order": """ Place a new long order and return it. For explanation of parameters, see `Order` and its properties. @@ -235,17 +289,22 @@ def buy(self, *, See also `Strategy.sell()`. """ - assert 0 < size < 1 or round(size) == size >= 1, \ + assert 0 < size < 1 or round(size) == size >= 1, ( "size must be a positive fraction of equity, or a positive whole number of units" - return self._broker.new_order(size, limit, stop, sl, tp, tag) - - def sell(self, *, - size: float = _FULL_EQUITY, - limit: Optional[float] = None, - stop: Optional[float] = None, - sl: Optional[float] = None, - tp: Optional[float] = None, - tag: object = None) -> 'Order': + ) + return self._broker.new_order(size, limit, stop, sl, tp, tag, symbol=symbol) + + def sell( + self, + *, + size: float = _FULL_EQUITY, + limit: Optional[float] = None, + stop: Optional[float] = None, + sl: Optional[float] = None, + tp: Optional[float] = None, + tag: object = None, + symbol: Optional[str] = None, + ) -> "Order": """ Place a new short order and return it. For explanation of parameters, see `Order` and its properties. @@ -266,9 +325,10 @@ def sell(self, *, If you merely want to close an existing long position, use `Position.close()` or `Trade.close()`. 
""" - assert 0 < size < 1 or round(size) == size >= 1, \ + assert 0 < size < 1 or round(size) == size >= 1, ( "size must be a positive fraction of equity, or a positive whole number of units" - return self._broker.new_order(-size, limit, stop, sl, tp, tag) + ) + return self._broker.new_order(-size, limit, stop, sl, tp, tag, symbol=symbol) @property def equity(self) -> float: @@ -276,7 +336,7 @@ def equity(self) -> float: return self._broker.equity @property - def data(self) -> _Data: + def data(self) -> Union[_Data, _MultiData]: """ Price data, roughly as passed into `backtesting.backtesting.Backtest.__init__`, @@ -305,22 +365,22 @@ def data(self) -> _Data: return self._data @property - def position(self) -> 'Position': + def position(self) -> "Position": """Instance of `backtesting.backtesting.Position`.""" return self._broker.position @property - def orders(self) -> 'Tuple[Order, ...]': + def orders(self) -> "Tuple[Order, ...]": """List of orders (see `Order`) waiting for execution.""" return tuple(self._broker.orders) @property - def trades(self) -> 'Tuple[Trade, ...]': + def trades(self) -> "Tuple[Trade, ...]": """List of active trades (see `Trade`).""" return tuple(self._broker.trades) @property - def closed_trades(self) -> 'Tuple[Trade, ...]': + def closed_trades(self) -> "Tuple[Trade, ...]": """List of settled trades (see `Trade`).""" return tuple(self._broker.closed_trades) @@ -335,8 +395,11 @@ class Position: if self.position: ... # we have a position, either long or short """ - def __init__(self, broker: '_Broker'): + + def __init__(self, broker: "_Broker", symbol: Optional[str] = None): self.__broker = broker + self.__symbol = symbol + self.__by_symbol: Dict[str, "Position"] = {} def __bool__(self): return self.size != 0 @@ -344,17 +407,20 @@ def __bool__(self): @property def size(self) -> float: """Position size in units of asset. 
Negative if position is short.""" - return sum(trade.size for trade in self.__broker.trades) + return sum(trade.size for trade in self.__broker._symbol_trades(self.__symbol)) @property def pl(self) -> float: """Profit (positive) or loss (negative) of the current position in cash units.""" - return sum(trade.pl for trade in self.__broker.trades) + return sum(trade.pl for trade in self.__broker._symbol_trades(self.__symbol)) @property def pl_pct(self) -> float: """Profit (positive) or loss (negative) of the current position in percent.""" - total_invested = sum(trade.entry_price * abs(trade.size) for trade in self.__broker.trades) + total_invested = sum( + trade.entry_price * abs(trade.size) + for trade in self.__broker._symbol_trades(self.__symbol) + ) return (self.pl / total_invested) * 100 if total_invested else 0 @property @@ -367,15 +433,28 @@ def is_short(self) -> bool: """True if the position is short (position size is negative).""" return self.size < 0 - def close(self, portion: float = 1.): + def close(self, portion: float = 1.0): """ Close portion of position by closing `portion` of each active trade. See `Trade.close`. 
         """
-        for trade in self.__broker.trades:
+        for trade in tuple(self.__broker._symbol_trades(self.__symbol)):
             trade.close(portion)
 
+    def __getitem__(self, symbol: str) -> "Position":
+        if len(self.__broker.symbols) == 1:
+            return self
+        if symbol not in self.__broker.symbols:
+            raise KeyError(f"Unknown symbol {symbol!r}")
+        return self.__by_symbol.setdefault(symbol, Position(self.__broker, symbol=symbol))
+
+    def get(self, symbol: str, default=None):
+        try:
+            return self[symbol]
+        except KeyError:
+            return default
+
     def __repr__(self):
-        return f'<Position: {self.size} ({len(self.__broker.trades)} trades)>'
+        return f"<Position: {self.size} ({len(self.__broker._symbol_trades(self.__symbol))} trades)>"
 
 
 class _OutOfMoneyError(Exception):
@@ -397,14 +476,19 @@ class Order:
     [filled]: https://www.investopedia.com/terms/f/fill.asp
     [Good 'Til Canceled]: https://www.investopedia.com/terms/g/gtc.asp
     """
-    def __init__(self, broker: '_Broker',
-                 size: float,
-                 limit_price: Optional[float] = None,
-                 stop_price: Optional[float] = None,
-                 sl_price: Optional[float] = None,
-                 tp_price: Optional[float] = None,
-                 parent_trade: Optional['Trade'] = None,
-                 tag: object = None):
+
+    def __init__(
+        self,
+        broker: "_Broker",
+        size: float,
+        limit_price: Optional[float] = None,
+        stop_price: Optional[float] = None,
+        sl_price: Optional[float] = None,
+        tp_price: Optional[float] = None,
+        parent_trade: Optional["Trade"] = None,
+        tag: object = None,
+        symbol: Optional[str] = None,
+    ):
         self.__broker = broker
         assert size != 0
         self.__size = size
@@ -414,23 +498,30 @@ def __init__(self, broker: '_Broker',
         self.__tp_price = tp_price
         self.__parent_trade = parent_trade
         self.__tag = tag
+        self.__symbol = symbol or broker.primary_symbol
 
     def _replace(self, **kwargs):
         for k, v in kwargs.items():
-            setattr(self, f'_{self.__class__.__qualname__}__{k}', v)
+            setattr(self, f"_{self.__class__.__qualname__}__{k}", v)
         return self
 
     def __repr__(self):
-        return '<Order {}>'.format(', '.join(f'{param}={try_(lambda: round(value, 5), value)!r}'
-                         for param, value in (
-                             ('size', self.__size),
-                             ('limit', self.__limit_price),
-                             ('stop', self.__stop_price),
-                             ('sl', self.__sl_price),
-                             ('tp', self.__tp_price),
-                             ('contingent', self.is_contingent),
-                             ('tag', self.__tag),
-                         ) if value is not None))  # noqa: E126
+        return "<Order {}>".format(
+            ", ".join(
+                f"{param}={try_(lambda: round(value, 5), value)!r}"
+                for param, value in (
+                    ("size", self.__size),
+                    ("limit", self.__limit_price),
+                    ("stop", self.__stop_price),
+                    ("sl", self.__sl_price),
+                    ("tp", self.__tp_price),
+                    ("contingent", self.is_contingent),
+                    ("symbol", self.__symbol),
+                    ("tag", self.__tag),
+                )
+                if value is not None
+            )
+        )  # noqa: E126
 
     def cancel(self):
         """Cancel the order."""
@@ -508,7 +599,11 @@ def tag(self):
         """
         return self.__tag
 
-    __pdoc__['Order.parent_trade'] = False
+    @property
+    def symbol(self) -> str:
+        return self.__symbol
+
+    __pdoc__["Order.parent_trade"] = False
 
     # Extra properties
 
@@ -534,9 +629,10 @@ def is_contingent(self):
         [contingent]: https://www.investopedia.com/terms/c/contingentorder.asp
         [OCO]: https://www.investopedia.com/terms/o/oco.asp
         """
-        return bool((parent := self.__parent_trade) and
-                    (self is parent._sl_order or
-                     self is parent._tp_order))
+        return bool(
+            (parent := self.__parent_trade)
+            and (self is parent._sl_order or self is parent._tp_order)
+        )
 
 
 class Trade:
@@ -544,7 +640,16 @@ class Trade:
     When an `Order` is filled, it results in an active `Trade`.
     Find active trades in `Strategy.trades` and closed, settled trades in
     `Strategy.closed_trades`.
""" - def __init__(self, broker: '_Broker', size: int, entry_price: float, entry_bar, tag): + + def __init__( + self, + broker: "_Broker", + size: int, + entry_price: float, + entry_bar, + tag, + symbol: str, + ): self.__broker = broker self.__size = size self.__entry_price = entry_price @@ -554,27 +659,30 @@ def __init__(self, broker: '_Broker', size: int, entry_price: float, entry_bar, self.__sl_order: Optional[Order] = None self.__tp_order: Optional[Order] = None self.__tag = tag + self.__symbol = symbol self._commissions = 0 def __repr__(self): - return f'' + return ( + f"" + ) def _replace(self, **kwargs): for k, v in kwargs.items(): - setattr(self, f'_{self.__class__.__qualname__}__{k}', v) + setattr(self, f"_{self.__class__.__qualname__}__{k}", v) return self def _copy(self, **kwargs): return copy(self)._replace(**kwargs) - def close(self, portion: float = 1.): + def close(self, portion: float = 1.0): """Place new `Order` to close `portion` of the trade at next market price.""" assert 0 < portion <= 1, "portion must be a fraction between 0 and 1" # Ensure size is an int to avoid rounding errors on 32-bit OS size = copysign(max(1, int(round(abs(self.__size) * portion))), -self.__size) - order = Order(self.__broker, size, parent_trade=self, tag=self.__tag) + order = Order(self.__broker, size, parent_trade=self, tag=self.__tag, symbol=self.__symbol) self.__broker.orders.insert(0, order) # Fields getters @@ -620,6 +728,10 @@ def tag(self): """ return self.__tag + @property + def symbol(self) -> str: + return self.__symbol + @property def _sl_order(self): return self.__sl_order @@ -633,14 +745,14 @@ def _tp_order(self): @property def entry_time(self) -> Union[pd.Timestamp, int]: """Datetime of when the trade was entered.""" - return self.__broker._data.index[self.__entry_bar] + return self.__broker._symbol_data(self.__symbol).index[self.__entry_bar] @property def exit_time(self) -> Optional[Union[pd.Timestamp, int]]: """Datetime of when the trade was exited.""" 
if self.__exit_bar is None: return None - return self.__broker._data.index[self.__exit_bar] + return self.__broker._symbol_data(self.__symbol).index[self.__exit_bar] @property def is_long(self): @@ -658,13 +770,13 @@ def pl(self): Trade profit (positive) or loss (negative) in cash units. Commissions are reflected only after the Trade is closed. """ - price = self.__exit_price or self.__broker.last_price + price = self.__exit_price or self.__broker.last_price(self.__symbol) return (self.__size * (price - self.__entry_price)) - self._commissions @property def pl_pct(self): """Trade profit (positive) or loss (negative) in percent relative to trade entry price.""" - price = self.__exit_price or self.__broker.last_price + price = self.__exit_price or self.__broker.last_price(self.__symbol) gross_pl_pct = copysign(1, self.__size) * (price / self.__entry_price - 1) # Total commission across the entire trade size to individual units @@ -674,7 +786,7 @@ def pl_pct(self): @property def value(self): """Trade total value in cash (volume × price).""" - price = self.__exit_price or self.__broker.last_price + price = self.__exit_price or self.__broker.last_price(self.__symbol) return abs(self.__size) * price # SL/TP management API @@ -692,7 +804,7 @@ def sl(self): @sl.setter def sl(self, price: float): - self.__set_contingent('sl', price) + self.__set_contingent("sl", price) @property def tp(self): @@ -707,27 +819,44 @@ def tp(self): @tp.setter def tp(self, price: float): - self.__set_contingent('tp', price) + self.__set_contingent("tp", price) def __set_contingent(self, type, price): - assert type in ('sl', 'tp') - assert price is None or 0 < price < np.inf, f'Make sure 0 < price < inf! price: {price}' - attr = f'_{self.__class__.__qualname__}__{type}_order' + assert type in ("sl", "tp") + assert price is None or 0 < price < np.inf, f"Make sure 0 < price < inf! 
price: {price}" + attr = f"_{self.__class__.__qualname__}__{type}_order" order: Order = getattr(self, attr) if order: order.cancel() if price: - kwargs = {'stop': price} if type == 'sl' else {'limit': price} + kwargs = {"stop": price} if type == "sl" else {"limit": price} order = self.__broker.new_order(-self.size, trade=self, tag=self.tag, **kwargs) setattr(self, attr, order) class _Broker: - def __init__(self, *, data, cash, spread, commission, margin, - trade_on_close, hedging, exclusive_orders, index): + def __init__( + self, + *, + data, + cash, + spread, + commission, + margin, + trade_on_close, + hedging, + exclusive_orders, + index, + ): assert cash > 0, f"cash should be > 0, is {cash}" assert 0 < margin <= 1, f"margin should be between 0 and 1, is {margin}" - self._data: _Data = data + self._data: Union[_Data, _MultiData] = data + if isinstance(data, _MultiData): + self.symbols = data.symbols + self.primary_symbol = data.primary_symbol + else: + self.symbols = ("_",) + self.primary_symbol = self.symbols[0] self._cash = cash if callable(commission): @@ -737,10 +866,11 @@ def __init__(self, *, data, cash, spread, commission, margin, self._commission_fixed, self._commission_relative = commission except TypeError: self._commission_fixed, self._commission_relative = 0, commission - assert self._commission_fixed >= 0, 'Need fixed cash commission in $ >= 0' - assert -.1 <= self._commission_relative < .1, \ - ("commission should be between -10% " - f"(e.g. market-maker's rebates) and 10% (fees), is {self._commission_relative}") + assert self._commission_fixed >= 0, "Need fixed cash commission in $ >= 0" + assert -0.1 <= self._commission_relative < 0.1, ( + "commission should be between -10% " + f"(e.g. 
market-maker's rebates) and 10% (fees), is {self._commission_relative}" + ) self._commission = self._commission_func self._spread = spread @@ -755,21 +885,35 @@ def __init__(self, *, data, cash, spread, commission, margin, self.position = Position(self) self.closed_trades: List[Trade] = [] + def _symbol_data(self, symbol: Optional[str] = None) -> _Data: + symbol = symbol or self.primary_symbol + if isinstance(self._data, _MultiData): + return self._data[symbol] + return self._data + + def _symbol_trades(self, symbol: Optional[str] = None): + if symbol is None: + return self.trades + return [trade for trade in self.trades if trade.symbol == symbol] + def _commission_func(self, order_size, price): return self._commission_fixed + abs(order_size) * price * self._commission_relative def __repr__(self): - return f'' - - def new_order(self, - size: float, - limit: Optional[float] = None, - stop: Optional[float] = None, - sl: Optional[float] = None, - tp: Optional[float] = None, - tag: object = None, - *, - trade: Optional[Trade] = None) -> Order: + return f"" + + def new_order( + self, + size: float, + limit: Optional[float] = None, + stop: Optional[float] = None, + sl: Optional[float] = None, + tp: Optional[float] = None, + tag: object = None, + symbol: Optional[str] = None, + *, + trade: Optional[Trade] = None, + ) -> Order: """ Argument size indicates whether the order is long or short """ @@ -781,29 +925,36 @@ def new_order(self, is_long = size > 0 assert size != 0, size - adjusted_price = self._adjusted_price(size) + symbol = symbol or (trade.symbol if trade else self.primary_symbol) + if len(self.symbols) > 1 and symbol not in self.symbols: + raise KeyError(f"Unknown symbol {symbol!r}") + if len(self.symbols) == 1: + symbol = self.primary_symbol + adjusted_price = self._adjusted_price(size, symbol=symbol) if is_long: if not (sl or -np.inf) < (limit or stop or adjusted_price) < (tp or np.inf): raise ValueError( "Long orders require: " - f"SL ({sl}) < LIMIT ({limit or 
stop or adjusted_price}) < TP ({tp})") + f"SL ({sl}) < LIMIT ({limit or stop or adjusted_price}) < TP ({tp})" + ) else: if not (tp or -np.inf) < (limit or stop or adjusted_price) < (sl or np.inf): raise ValueError( "Short orders require: " - f"TP ({tp}) < LIMIT ({limit or stop or adjusted_price}) < SL ({sl})") + f"TP ({tp}) < LIMIT ({limit or stop or adjusted_price}) < SL ({sl})" + ) - order = Order(self, size, limit, stop, sl, tp, trade, tag) + order = Order(self, size, limit, stop, sl, tp, trade, tag, symbol=symbol) if not trade: # If exclusive orders (each new order auto-closes previous orders/position), # cancel all non-contingent orders and close all open trades beforehand if self._exclusive_orders: for o in self.orders: - if not o.is_contingent: + if not o.is_contingent and o.symbol == symbol: o.cancel() - for t in self.trades: + for t in tuple(self._symbol_trades(symbol)): t.close() # Put the new order in the order queue, Ensure SL orders are processed first @@ -811,17 +962,16 @@ def new_order(self, return order - @property - def last_price(self) -> float: - """ Price at the last (current) close. """ - return self._data.Close[-1] + def last_price(self, symbol: Optional[str] = None) -> float: + """Price at the last (current) close.""" + return self._symbol_data(symbol).Close[-1] - def _adjusted_price(self, size=None, price=None) -> float: + def _adjusted_price(self, size=None, price=None, symbol: Optional[str] = None) -> float: """ Long/short `price`, adjusted for spread. In long positions, the adjusted price is a fraction higher, and vice versa. 
""" - return (price or self.last_price) * (1 + copysign(self._spread, size)) + return (price or self.last_price(symbol)) * (1 + copysign(self._spread, size)) @property def equity(self) -> float: @@ -834,7 +984,7 @@ def margin_available(self) -> float: return max(0, self.equity - margin_used) def next(self): - i = self._i = len(self._data) - 1 + i = self._i = len(self._symbol_data()) - 1 self._process_orders() # Log account equity for the equity curve @@ -844,28 +994,28 @@ def next(self): # If equity is negative, set all to 0 and stop the simulation if equity <= 0: assert self.margin_available <= 0 - for trade in self.trades: - self._close_trade(trade, self._data.Close[-1], i) + for trade in tuple(self.trades): + self._close_trade(trade, self.last_price(trade.symbol), i) self._cash = 0 self._equity[i:] = 0 raise _OutOfMoneyError def _process_orders(self): - data = self._data - open, high, low = data.Open[-1], data.High[-1], data.Low[-1] reprocess_orders = False # Process orders for order in list(self.orders): # type: Order - # Related SL/TP order was already removed if order not in self.orders: continue + data = self._symbol_data(order.symbol) + open, high, low = data.Open[-1], data.High[-1], data.Low[-1] + # Check if stop condition was hit stop_price = order.stop if stop_price: - is_stop_hit = ((high >= stop_price) if order.is_long else (low <= stop_price)) + is_stop_hit = (high >= stop_price) if order.is_long else (low <= stop_price) if not is_stop_hit: continue @@ -879,17 +1029,20 @@ def _process_orders(self): is_limit_hit = low <= order.limit if order.is_long else high >= order.limit # When stop and limit are hit within the same bar, we pessimistically # assume limit was hit before the stop (i.e. 
"before it counts") - is_limit_hit_before_stop = (is_limit_hit and - (order.limit <= (stop_price or -np.inf) - if order.is_long - else order.limit >= (stop_price or np.inf))) + is_limit_hit_before_stop = is_limit_hit and ( + order.limit <= (stop_price or -np.inf) + if order.is_long + else order.limit >= (stop_price or np.inf) + ) if not is_limit_hit or is_limit_hit_before_stop: continue # stop_price, if set, was hit within this bar - price = (min(stop_price or open, order.limit) - if order.is_long else - max(stop_price or open, order.limit)) + price = ( + min(stop_price or open, order.limit) + if order.is_long + else max(stop_price or open, order.limit) + ) else: # Market-if-touched / market order # Contingent orders always on next open @@ -902,8 +1055,9 @@ def _process_orders(self): is_market_order = not order.limit and not stop_price time_index = ( (self._i - 1) - if is_market_order and self._trade_on_close and not order.is_contingent else - self._i) + if is_market_order and self._trade_on_close and not order.is_contingent + else self._i + ) # If order is a SL/TP order, it should close an existing trade it was contingent upon if order.parent_trade: @@ -919,8 +1073,7 @@ def _process_orders(self): if price == stop_price: # Set SL back on the order for stats._trades["SL"] trade._sl_order._replace(stop_price=stop_price) - if order in (trade._sl_order, - trade._tp_order): + if order in (trade._sl_order, trade._tp_order): assert order.size == -trade.size assert order not in self.orders # Removed when trade was closed else: @@ -933,22 +1086,29 @@ def _process_orders(self): # Adjust price to include commission (or bid-ask spread). # In long positions, the adjusted price is a fraction higher, and vice versa. 
- adjusted_price = self._adjusted_price(order.size, price) - adjusted_price_plus_commission = \ - adjusted_price + self._commission(order.size, price) / abs(order.size) + adjusted_price = self._adjusted_price(order.size, price, symbol=order.symbol) + adjusted_price_plus_commission = adjusted_price + self._commission( + order.size, price + ) / abs(order.size) # If order size was specified proportionally, # precompute true size in units, accounting for margin and spread/commissions size = order.size if -1 < size < 1: - size = copysign(int((self.margin_available * self._leverage * abs(size)) - // adjusted_price_plus_commission), size) + size = copysign( + int( + (self.margin_available * self._leverage * abs(size)) + // adjusted_price_plus_commission + ), + size, + ) # Not enough cash/margin even for a single unit if not size: warnings.warn( - f'time={self._i}: Broker canceled the relative-sized order due to insufficient margin ' - f'(equity={self.equity:.2f}, margin_available={self.margin_available:.2f}).', - category=UserWarning) + f"time={self._i}: Broker canceled the relative-sized order due to insufficient margin " + f"(equity={self.equity:.2f}, margin_available={self.margin_available:.2f}).", + category=UserWarning, + ) # XXX: The order is canceled by the broker? self.orders.remove(order) continue @@ -960,6 +1120,8 @@ def _process_orders(self): # Existing trades are closed at unadjusted price, because the adjustment # was already made when buying. 
for trade in list(self.trades): + if trade.symbol != order.symbol: + continue if trade.is_long == order.is_long: continue assert trade.size * order.size < 0 @@ -979,23 +1141,29 @@ def _process_orders(self): break # If we don't have enough liquidity to cover for the order, the broker CANCELS it - if abs(need_size) * adjusted_price_plus_commission > \ - self.margin_available * self._leverage: + if ( + abs(need_size) * adjusted_price_plus_commission + > self.margin_available * self._leverage + ): warnings.warn( - f'time={self._i}: Broker canceled the order due to insufficient margin ' - f'(equity={self.equity:.2f}, margin_available={self.margin_available:.2f}).', - category=UserWarning) + f"time={self._i}: Broker canceled the order due to insufficient margin " + f"(equity={self.equity:.2f}, margin_available={self.margin_available:.2f}).", + category=UserWarning, + ) self.orders.remove(order) continue # Open a new trade if need_size: - self._open_trade(adjusted_price, - need_size, - order.sl, - order.tp, - time_index, - order.tag) + self._open_trade( + adjusted_price, + need_size, + order.sl, + order.tp, + time_index, + order.tag, + order.symbol, + ) # We need to reprocess the SL/TP orders newly added to the queue. # This allows e.g. SL hitting in the same bar the order was open. @@ -1005,12 +1173,19 @@ def _process_orders(self): reprocess_orders = True # Order.stop and TP hit within the same bar, but SL wasn't. This case # is not ambiguous, because stop and TP go in the same price direction. 
- elif stop_price and not order.limit and order.tp and ( - (order.is_long and order.tp <= high and (order.sl or -np.inf) < low) or - (order.is_short and order.tp >= low and (order.sl or np.inf) > high)): + elif ( + stop_price + and not order.limit + and order.tp + and ( + (order.is_long and order.tp <= high and (order.sl or -np.inf) < low) + or (order.is_short and order.tp >= low and (order.sl or np.inf) > high) + ) + ): reprocess_orders = True - elif (low <= (order.sl or -np.inf) <= high or - low <= (order.tp or -np.inf) <= high): + elif ( + low <= (order.sl or -np.inf) <= high or low <= (order.tp or -np.inf) <= high + ): warnings.warn( f"({data.index[-1]}) A contingent SL/TP order would execute in the " "same bar its parent stop/limit order was turned into a trade. " @@ -1019,7 +1194,8 @@ def _process_orders(self): "the next (matching) price/bar, making the result (of this trade) " "somewhat dubious. " "See https://github.com/kernc/backtesting.py/issues/119", - UserWarning) + UserWarning, + ) # Order processed self.orders.remove(order) @@ -1067,9 +1243,17 @@ def _close_trade(self, trade: Trade, price: float, time_index: int): # by way of _reduce_trade() closed_trade._commissions = commission + trade_open_commission - def _open_trade(self, price: float, size: int, - sl: Optional[float], tp: Optional[float], time_index: int, tag): - trade = Trade(self, size, price, time_index, tag) + def _open_trade( + self, + price: float, + size: int, + sl: Optional[float], + tp: Optional[float], + time_index: int, + tag, + symbol: str, + ): + trade = Trade(self, size, price, time_index, tag, symbol) self.trades.append(trade) # Apply broker commission at trade open self._cash -= self._commission(size, price) @@ -1164,77 +1348,139 @@ class Backtest: [FIFO]: https://www.investopedia.com/terms/n/nfa-compliance-rule-2-43b.asp [active and ongoing]: https://kernc.github.io/backtesting.py/doc/backtesting/backtesting.html#backtesting.backtesting.Strategy.trades """ # noqa: E501 - def 
__init__(self, - data: pd.DataFrame, - strategy: Type[Strategy], - *, - cash: float = 10_000, - spread: float = .0, - commission: Union[float, Tuple[float, float]] = .0, - margin: float = 1., - trade_on_close=False, - hedging=False, - exclusive_orders=False, - finalize_trades=False, - ): + + def __init__( + self, + data: Union[pd.DataFrame, Mapping[str, pd.DataFrame], Sequence[pd.DataFrame]], + strategy: Type[Strategy], + *, + cash: float = 10_000, + spread: float = 0.0, + commission: Union[float, Tuple[float, float]] = 0.0, + margin: float = 1.0, + trade_on_close=False, + hedging=False, + exclusive_orders=False, + finalize_trades=False, + ): if not (isinstance(strategy, type) and issubclass(strategy, Strategy)): - raise TypeError('`strategy` must be a Strategy sub-type') - if not isinstance(data, pd.DataFrame): - raise TypeError("`data` must be a pandas.DataFrame with columns") + raise TypeError("`strategy` must be a Strategy sub-type") + if isinstance(data, Mapping): + if not data: + raise ValueError("`data` mapping must not be empty") + assets = {str(symbol): df for symbol, df in data.items()} + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes, pd.DataFrame)): + if not data: + raise ValueError("`data` sequence must not be empty") + assets = {str(i): df for i, df in enumerate(data)} + elif isinstance(data, pd.DataFrame): + assets = None + else: + raise TypeError("`data` must be a pandas.DataFrame or mapping of symbol -> DataFrame") if not isinstance(spread, Number): - raise TypeError('`spread` must be a float value, percent of ' - 'entry order price') + raise TypeError("`spread` must be a float value, percent of entry order price") if not isinstance(commission, (Number, tuple)) and not callable(commission): - raise TypeError('`commission` must be a float percent of order value, ' - 'a tuple of `(fixed, relative)` commission, ' - 'or a function that takes `(order_size, price)`' - 'and returns commission dollar value') - - data = 
data.copy(deep=False) - - # Convert index to datetime index - if (not isinstance(data.index, pd.DatetimeIndex) and - not isinstance(data.index, pd.RangeIndex) and - # Numeric index with most large numbers - (data.index.is_numeric() and - (data.index > pd.Timestamp('1975').timestamp()).mean() > .8)): - try: - data.index = pd.to_datetime(data.index, infer_datetime_format=True) - except ValueError: - pass - - if 'Volume' not in data: - data['Volume'] = np.nan - - if len(data) == 0: - raise ValueError('OHLC `data` is empty') - if len(data.columns.intersection({'Open', 'High', 'Low', 'Close', 'Volume'})) != 5: - raise ValueError("`data` must be a pandas.DataFrame with columns " - "'Open', 'High', 'Low', 'Close', and (optionally) 'Volume'") - if data[['Open', 'High', 'Low', 'Close']].isnull().values.any(): - raise ValueError('Some OHLC values are missing (NaN). ' - 'Please strip those lines with `df.dropna()` or ' - 'fill them in with `df.interpolate()` or whatever.') - if np.any(data['Close'] > cash): - warnings.warn('Some prices are larger than initial cash value. Note that fractional ' - 'trading is not supported by this class. If you want to trade Bitcoin, ' - 'increase initial cash, or trade μBTC or satoshis instead (see e.g. class ' - '`backtesting.lib.FractionalBacktest`.', - stacklevel=2) - if not data.index.is_monotonic_increasing: - warnings.warn('Data index is not sorted in ascending order. Sorting.', - stacklevel=2) - data = data.sort_index() - if not isinstance(data.index, pd.DatetimeIndex): - warnings.warn('Data index is not datetime. 
Assuming simple periods, ' - 'but `pd.DateTimeIndex` is advised.', - stacklevel=2) - - self._data: pd.DataFrame = data + raise TypeError( + "`commission` must be a float percent of order value, " + "a tuple of `(fixed, relative)` commission, " + "or a function that takes `(order_size, price)`" + "and returns commission dollar value" + ) + + def _prepare_single_data(df: pd.DataFrame, *, symbol: Optional[str] = None) -> pd.DataFrame: + if not isinstance(df, pd.DataFrame): + if symbol is None: + raise TypeError("`data` must be a pandas.DataFrame with columns") + raise TypeError(f"`data[{symbol!r}]` must be a pandas.DataFrame with columns") + + df = df.copy(deep=False) + + # Convert index to datetime index + if ( + not isinstance(df.index, pd.DatetimeIndex) + and not isinstance(df.index, pd.RangeIndex) + and + # Numeric index with most large numbers + ( + df.index.is_numeric() + and (df.index > pd.Timestamp("1975").timestamp()).mean() > 0.8 + ) + ): + try: + df.index = pd.to_datetime(df.index, infer_datetime_format=True) + except ValueError: + pass + + if "Volume" not in df: + df["Volume"] = np.nan + + if len(df) == 0: + raise ValueError("OHLC `data` is empty") + if len(df.columns.intersection({"Open", "High", "Low", "Close", "Volume"})) != 5: + raise ValueError( + "`data` must be a pandas.DataFrame with columns " + "'Open', 'High', 'Low', 'Close', and (optionally) 'Volume'" + ) + if df[["Open", "High", "Low", "Close"]].isnull().values.any(): + raise ValueError( + "Some OHLC values are missing (NaN). " + "Please strip those lines with `df.dropna()` or " + "fill them in with `df.interpolate()` or whatever." + ) + if np.any(df["Close"] > cash): + warnings.warn( + "Some prices are larger than initial cash value. Note that fractional " + "trading is not supported by this class. If you want to trade Bitcoin, " + "increase initial cash, or trade μBTC or satoshis instead (see e.g. 
class " + "`backtesting.lib.FractionalBacktest`.", + stacklevel=2, + ) + if not df.index.is_monotonic_increasing: + warnings.warn("Data index is not sorted in ascending order. Sorting.", stacklevel=2) + df = df.sort_index() + if not isinstance(df.index, pd.DatetimeIndex): + warnings.warn( + "Data index is not datetime. Assuming simple periods, " + "but `pd.DateTimeIndex` is advised.", + stacklevel=2, + ) + return df + + if assets is None: + data = _prepare_single_data(data) + self._data: Union[pd.DataFrame, Dict[str, pd.DataFrame]] = data + self._is_multi_asset = False + self._primary_symbol = "_" + self._asset_data = {self._primary_symbol: data} + else: + prepared_assets = { + symbol: _prepare_single_data(df, symbol=symbol) for symbol, df in assets.items() + } + indexes = {symbol: df.index for symbol, df in prepared_assets.items()} + first_symbol = next(iter(prepared_assets)) + first_index = indexes[first_symbol] + mismatch = [s for s, idx in indexes.items() if not idx.equals(first_index)] + if mismatch: + raise ValueError( + "All asset DataFrames must have identical, aligned indexes in multi-asset mode. " + f"Mismatched symbols: {', '.join(mismatch)}" + ) + self._data = prepared_assets + self._is_multi_asset = True + self._primary_symbol = first_symbol + self._asset_data = prepared_assets + data = prepared_assets[first_symbol] + self._broker = partial( - _Broker, cash=cash, spread=spread, commission=commission, margin=margin, - trade_on_close=trade_on_close, hedging=hedging, - exclusive_orders=exclusive_orders, index=data.index, + _Broker, + cash=cash, + spread=spread, + commission=commission, + margin=margin, + trade_on_close=trade_on_close, + hedging=hedging, + exclusive_orders=exclusive_orders, + index=data.index, ) self._strategy = strategy self._results: Optional[pd.Series] = None @@ -1290,7 +1536,10 @@ def run(self, **kwargs) -> pd.Series: period of the `Strategy.I` indicator which lags the most. Obviously, this can affect results. 
""" - data = _Data(self._data.copy(deep=False)) + if self._is_multi_asset: + data = _MultiData({s: df.copy(deep=False) for s, df in self._asset_data.items()}) + else: + data = _Data(self._asset_data[self._primary_symbol].copy(deep=False)) broker: _Broker = self._broker(data=data) strategy: Strategy = self._strategy(broker, data, kwargs) @@ -1306,15 +1555,19 @@ def run(self, **kwargs) -> pd.Series: # Disable "invalid value encountered in ..." warnings. Comparison # np.nan >= 3 is not invalid; it's False. - with np.errstate(invalid='ignore'): - - for i in _tqdm(range(start, len(self._data)), desc=self.run.__qualname__, - unit='bar', mininterval=2, miniters=100): + with np.errstate(invalid="ignore"): + for i in _tqdm( + range(start, len(self._asset_data[self._primary_symbol])), + desc=self.run.__qualname__, + unit="bar", + mininterval=2, + miniters=100, + ): # Prepare data and indicators for `next` call data._set_length(i + 1) for attr, indicator in indicator_attrs: # Slice indicator on the last dimension (case of 2d indicator) - setattr(strategy, attr, indicator[..., :i + 1]) + setattr(strategy, attr, indicator[..., : i + 1]) # Handle orders processing and broker stuff try: @@ -1332,17 +1585,19 @@ def run(self, **kwargs) -> pd.Series: # HACK: Re-run broker one last time to handle close orders placed in the last # strategy iteration. Use the same OHLC values as in the last broker iteration. - if start < len(self._data): + if start < len(self._asset_data[self._primary_symbol]): try_(broker.next, exception=_OutOfMoneyError) elif len(broker.trades): warnings.warn( - 'Some trades remain open at the end of backtest. Use ' - '`Backtest(..., finalize_trades=True)` to close them and ' - 'include them in stats.', stacklevel=2) + "Some trades remain open at the end of backtest. 
Use " + "`Backtest(..., finalize_trades=True)` to close them and " + "include them in stats.", + stacklevel=2, + ) # Set data back to full length # for future `indicator._opts['data'].index` calls to work - data._set_length(len(self._data)) + data._set_length(len(self._asset_data[self._primary_symbol])) equity = pd.Series(broker._equity).bfill().fillna(broker._cash).values self._results = compute_stats( @@ -1355,17 +1610,18 @@ def run(self, **kwargs) -> pd.Series: return self._results - def optimize(self, *, - maximize: Union[str, Callable[[pd.Series], float]] = 'SQN', - method: str = 'grid', - max_tries: Optional[Union[int, float]] = None, - constraint: Optional[Callable[[dict], bool]] = None, - return_heatmap: bool = False, - return_optimization: bool = False, - random_state: Optional[int] = None, - **kwargs) -> Union[pd.Series, - Tuple[pd.Series, pd.Series], - Tuple[pd.Series, pd.Series, dict]]: + def optimize( + self, + *, + maximize: Union[str, Callable[[pd.Series], float]] = "SQN", + method: str = "grid", + max_tries: Optional[Union[int, float]] = None, + constraint: Optional[Callable[[dict], bool]] = None, + return_heatmap: bool = False, + return_optimization: bool = False, + random_state: Optional[int] = None, + **kwargs, + ) -> Union[pd.Series, Tuple[pd.Series, pd.Series], Tuple[pd.Series, pd.Series, dict]]: """ Optimize strategy parameters to an optimal combination. Returns result `pd.Series` of the best run. 
@@ -1425,22 +1681,25 @@ def optimize(self, *, constraint=lambda p: p.sma1 < p.sma2) """ if not kwargs: - raise ValueError('Need some strategy parameters to optimize') + raise ValueError("Need some strategy parameters to optimize") maximize_key = None if isinstance(maximize, str): maximize_key = str(maximize) if maximize not in dummy_stats().index: - raise ValueError('`maximize`, if str, must match a key in pd.Series ' - 'result of backtest.run()') + raise ValueError( + "`maximize`, if str, must match a key in pd.Series result of backtest.run()" + ) def maximize(stats: pd.Series, _key=maximize): return stats[_key] elif not callable(maximize): - raise TypeError('`maximize` must be str (a field of backtest.run() result ' - 'Series) or a function that accepts result Series ' - 'and returns a number; the higher the better') + raise TypeError( + "`maximize` must be str (a field of backtest.run() result " + "Series) or a function that accepts result Series " + "and returns a number; the higher the better" + ) assert callable(maximize), maximize have_constraint = bool(constraint) @@ -1450,16 +1709,21 @@ def constraint(_): return True elif not callable(constraint): - raise TypeError("`constraint` must be a function that accepts a dict " - "of strategy parameters and returns a bool whether " - "the combination of parameters is admissible or not") + raise TypeError( + "`constraint` must be a function that accepts a dict " + "of strategy parameters and returns a bool whether " + "the combination of parameters is admissible or not" + ) assert callable(constraint), constraint - if method == 'skopt': - method = 'sambo' - warnings.warn('`Backtest.optimize(method="skopt")` is deprecated. Use `method="sambo"`.', - DeprecationWarning, stacklevel=2) - if return_optimization and method != 'sambo': + if method == "skopt": + method = "sambo" + warnings.warn( + '`Backtest.optimize(method="skopt")` is deprecated. 
Use `method="sambo"`.', + DeprecationWarning, + stacklevel=2, + ) + if return_optimization and method != "sambo": raise ValueError("return_optimization=True only valid if method='sambo'") def _tuple(x): @@ -1467,8 +1731,9 @@ def _tuple(x): for k, v in kwargs.items(): if len(_tuple(v)) == 0: - raise ValueError(f"Optimization variable '{k}' is passed no " - f"optimization values: {k}={v}") + raise ValueError( + f"Optimization variable '{k}' is passed no optimization values: {k}={v}" + ) class AttrDict(dict): def __getattr__(self, item): @@ -1477,51 +1742,78 @@ def __getattr__(self, item): def _grid_size(): size = int(np.prod([len(_tuple(v)) for v in kwargs.values()])) if size < 10_000 and have_constraint: - size = sum(1 for p in product(*(zip(repeat(k), _tuple(v)) - for k, v in kwargs.items())) - if constraint(AttrDict(p))) + size = sum( + 1 + for p in product(*(zip(repeat(k), _tuple(v)) for k, v in kwargs.items())) + if constraint(AttrDict(p)) + ) return size def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]: rand = default_rng(random_state).random - grid_frac = (1 if max_tries is None else - max_tries if 0 < max_tries <= 1 else - max_tries / _grid_size()) - param_combos = [dict(params) # back to dict so it pickles - for params in (AttrDict(params) - for params in product(*(zip(repeat(k), _tuple(v)) - for k, v in kwargs.items()))) - if constraint(params) - and rand() <= grid_frac] + grid_frac = ( + 1 + if max_tries is None + else max_tries + if 0 < max_tries <= 1 + else max_tries / _grid_size() + ) + param_combos = [ + dict(params) # back to dict so it pickles + for params in ( + AttrDict(params) + for params in product(*(zip(repeat(k), _tuple(v)) for k, v in kwargs.items())) + ) + if constraint(params) and rand() <= grid_frac + ] if not param_combos: - raise ValueError('No admissible parameter combinations to test') + raise ValueError("No admissible parameter combinations to test") if len(param_combos) > 300: - warnings.warn(f'Searching for best 
of {len(param_combos)} configurations.', - stacklevel=2) - - heatmap = pd.Series(np.nan, - name=maximize_key, - index=pd.MultiIndex.from_tuples( - [p.values() for p in param_combos], - names=next(iter(param_combos)).keys())) - - from . import Pool - with Pool() as pool, \ - SharedMemoryManager() as smm: - with patch(self, '_data', None): - bt = copy(self) # bt._data will be reassigned in _mp_task worker - results = _tqdm( - pool.imap(Backtest._mp_task, - ((bt, smm.df2shm(self._data), params_batch) - for params_batch in _batch(param_combos))), - total=len(param_combos), - desc='Backtest.optimize' + warnings.warn( + f"Searching for best of {len(param_combos)} configurations.", stacklevel=2 ) - for param_batch, result in zip(_batch(param_combos), results): - for params, stats in zip(param_batch, result): - if stats is not None: - heatmap[tuple(params.values())] = maximize(stats) + + heatmap = pd.Series( + np.nan, + name=maximize_key, + index=pd.MultiIndex.from_tuples( + [p.values() for p in param_combos], names=next(iter(param_combos)).keys() + ), + ) + + if self._is_multi_asset: + for params in _tqdm( + param_combos, total=len(param_combos), desc="Backtest.optimize" + ): + stats = self.run(**params) + if stats["# Trades"]: + heatmap[tuple(params.values())] = maximize(stats) + else: + from . import Pool + + with Pool() as pool, SharedMemoryManager() as smm: + with patch(self, "_data", None): + bt = copy(self) # bt._data will be reassigned in _mp_task worker + results = _tqdm( + pool.imap( + Backtest._mp_task, + ( + ( + bt, + smm.df2shm(self._asset_data[self._primary_symbol]), + params_batch, + ) + for params_batch in _batch(param_combos) + ), + ), + total=len(param_combos), + desc="Backtest.optimize", + ) + for param_batch, result in zip(_batch(param_combos), results): + for params, stats in zip(param_batch, result): + if stats is not None: + heatmap[tuple(params.values())] = maximize(stats) if pd.isnull(heatmap).all(): # No trade was made in any of the runs. 
Just make a random @@ -1535,30 +1827,36 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]: return stats, heatmap return stats - def _optimize_sambo() -> Union[pd.Series, - Tuple[pd.Series, pd.Series], - Tuple[pd.Series, pd.Series, dict]]: + def _optimize_sambo() -> Union[ + pd.Series, Tuple[pd.Series, pd.Series], Tuple[pd.Series, pd.Series, dict] + ]: try: import sambo except ImportError: - raise ImportError("Need package 'sambo' for method='sambo'. pip install sambo") from None + raise ImportError( + "Need package 'sambo' for method='sambo'. pip install sambo" + ) from None nonlocal max_tries - max_tries = (200 if max_tries is None else - max(1, int(max_tries * _grid_size())) if 0 < max_tries <= 1 else - max_tries) + max_tries = ( + 200 + if max_tries is None + else max(1, int(max_tries * _grid_size())) + if 0 < max_tries <= 1 + else max_tries + ) dimensions = [] for key, values in kwargs.items(): values = np.asarray(values) - if values.dtype.kind in 'mM': # timedelta, datetime64 + if values.dtype.kind in "mM": # timedelta, datetime64 # these dtypes are unsupported in SAMBO, so convert to raw int # TODO: save dtype and convert back later values = values.astype(np.int64) - if values.dtype.kind in 'iumM': + if values.dtype.kind in "iumM": dimensions.append((values.min(), values.max() + 1)) - elif values.dtype.kind == 'f': + elif values.dtype.kind == "f": dimensions.append((values.min(), values.max())) else: dimensions.append(values.tolist()) @@ -1570,8 +1868,15 @@ def memoized_run(tup): stats = self.run(**dict(tup)) return -maximize(stats) - progress = iter(_tqdm(repeat(None), total=max_tries, leave=False, - desc=self.optimize.__qualname__, mininterval=2)) + progress = iter( + _tqdm( + repeat(None), + total=max_tries, + leave=False, + desc=self.optimize.__qualname__, + mininterval=2, + ) + ) _names = tuple(kwargs.keys()) def objective_function(x): @@ -1589,15 +1894,15 @@ def cons(x): bounds=dimensions, constraints=cons, max_iter=max_tries, - 
method='sceua', - rng=random_state) + method="sceua", + rng=random_state, + ) stats = self.run(**dict(zip(kwargs.keys(), res.x))) output = [stats] if return_heatmap: - heatmap = pd.Series(dict(zip(map(tuple, res.xv), -res.funv)), - name=maximize_key) + heatmap = pd.Series(dict(zip(map(tuple, res.xv), -res.funv)), name=maximize_key) heatmap.index.names = kwargs.keys() heatmap.sort_index(inplace=True) output.append(heatmap) @@ -1607,9 +1912,9 @@ def cons(x): return stats if len(output) == 1 else tuple(output) - if method == 'grid': + if method == "grid": output = _optimize_grid() - elif method in ('sambo', 'skopt'): + elif method in ("sambo", "skopt"): output = _optimize_sambo() else: raise ValueError(f"Method should be 'grid' or 'sambo', not {method!r}") @@ -1618,22 +1923,40 @@ def cons(x): @staticmethod def _mp_task(arg): bt, data_shm, params_batch = arg - bt._data, shm = SharedMemoryManager.shm2df(data_shm) + data, shm = SharedMemoryManager.shm2df(data_shm) + bt._data = data + bt._is_multi_asset = False + bt._primary_symbol = "_" + bt._asset_data = {bt._primary_symbol: data} try: - return [stats.filter(regex='^[^_]') if stats['# Trades'] else None - for stats in (bt.run(**params) - for params in params_batch)] + return [ + stats.filter(regex="^[^_]") if stats["# Trades"] else None + for stats in (bt.run(**params) for params in params_batch) + ] finally: for shmem in shm: shmem.close() - def plot(self, *, results: pd.Series = None, filename=None, plot_width=None, - plot_equity=True, plot_return=False, plot_pl=True, - plot_volume=True, plot_drawdown=False, plot_trades=True, - smooth_equity=False, relative_equity=True, - superimpose: Union[bool, str] = True, - resample=True, reverse_indicators=False, - show_legend=True, open_browser=True): + def plot( + self, + *, + results: pd.Series = None, + filename=None, + plot_width=None, + plot_equity=True, + plot_return=False, + plot_pl=True, + plot_volume=True, + plot_drawdown=False, + plot_trades=True, + 
smooth_equity=False, + relative_equity=True, + superimpose: Union[bool, str] = True, + resample=True, + reverse_indicators=False, + show_legend=True, + open_browser=True, + ): """ Plot the progression of the last backtest run. @@ -1715,12 +2038,13 @@ def plot(self, *, results: pd.Series = None, filename=None, plot_width=None, """ if results is None: if self._results is None: - raise RuntimeError('First issue `backtest.run()` to obtain results.') + raise RuntimeError("First issue `backtest.run()` to obtain results.") results = self._results return plot( results=results, - df=self._data, + df=self._asset_data[self._primary_symbol], + assets=self._asset_data if self._is_multi_asset else None, indicators=results._strategy._indicators, filename=filename, plot_width=plot_width, @@ -1736,15 +2060,24 @@ def plot(self, *, results: pd.Series = None, filename=None, plot_width=None, resample=resample, reverse_indicators=reverse_indicators, show_legend=show_legend, - open_browser=open_browser) + open_browser=open_browser, + ) # NOTE: Don't put anything public below this __all__ list -__all__ = [getattr(v, '__name__', k) - for k, v in globals().items() # export - if ((callable(v) and getattr(v, '__module__', None) == __name__ or # callables from this module; getattr for Python 3.9; # noqa: E501 - k.isupper()) and # or CONSTANTS - not getattr(v, '__name__', k).startswith('_'))] # neither marked internal +__all__ = [ + getattr(v, "__name__", k) + for k, v in globals().items() # export + if ( + ( + callable(v) + and getattr(v, "__module__", None) + == __name__ # callables from this module; getattr for Python 3.9; # noqa: E501 + or k.isupper() + ) # or CONSTANTS + and not getattr(v, "__name__", k).startswith("_") + ) +] # neither marked internal # NOTE: Don't put anything public below here. See above. 
diff --git a/backtesting/lib.py b/backtesting/lib.py index fd05026a..9b97af8a 100644 --- a/backtesting/lib.py +++ b/backtesting/lib.py @@ -18,7 +18,7 @@ from inspect import currentframe from itertools import chain, compress, count from numbers import Number -from typing import Callable, Generator, Optional, Sequence, Union +from typing import Callable, Generator, Mapping, Optional, Sequence, Union import numpy as np import pandas as pd @@ -31,31 +31,35 @@ __pdoc__ = {} -OHLCV_AGG = OrderedDict(( - ('Open', 'first'), - ('High', 'max'), - ('Low', 'min'), - ('Close', 'last'), - ('Volume', 'sum'), -)) +OHLCV_AGG = OrderedDict( + ( + ("Open", "first"), + ("High", "max"), + ("Low", "min"), + ("Close", "last"), + ("Volume", "sum"), + ) +) """Dictionary of rules for aggregating resampled OHLCV data frames, e.g. df.resample('4H', label='right').agg(OHLCV_AGG).dropna() """ -TRADES_AGG = OrderedDict(( - ('Size', 'sum'), - ('EntryBar', 'first'), - ('ExitBar', 'last'), - ('EntryPrice', 'mean'), - ('ExitPrice', 'mean'), - ('PnL', 'sum'), - ('ReturnPct', 'mean'), - ('EntryTime', 'first'), - ('ExitTime', 'last'), - ('Duration', 'sum'), -)) +TRADES_AGG = OrderedDict( + ( + ("Size", "sum"), + ("EntryBar", "first"), + ("ExitBar", "last"), + ("EntryPrice", "mean"), + ("ExitPrice", "mean"), + ("PnL", "sum"), + ("ReturnPct", "mean"), + ("EntryTime", "first"), + ("ExitTime", "last"), + ("Duration", "sum"), + ) +) """Dictionary of rules for aggregating resampled trades data, e.g. 
@@ -64,9 +68,9 @@ """ _EQUITY_AGG = { - 'Equity': 'last', - 'DrawdownPct': 'max', - 'DrawdownDuration': 'max', + "Equity": "last", + "DrawdownPct": "max", + "DrawdownDuration": "max", } @@ -102,26 +106,34 @@ def crossover(series1: Sequence, series2: Sequence) -> bool: True """ series1 = ( - series1.values if isinstance(series1, pd.Series) else - (series1, series1) if isinstance(series1, Number) else - series1) + series1.values + if isinstance(series1, pd.Series) + else (series1, series1) + if isinstance(series1, Number) + else series1 + ) series2 = ( - series2.values if isinstance(series2, pd.Series) else - (series2, series2) if isinstance(series2, Number) else - series2) + series2.values + if isinstance(series2, pd.Series) + else (series2, series2) + if isinstance(series2, Number) + else series2 + ) try: return series1[-2] < series2[-2] and series1[-1] > series2[-1] # type: ignore except IndexError: return False -def plot_heatmaps(heatmap: pd.Series, - agg: Union[str, Callable] = 'max', - *, - ncols: int = 3, - plot_width: int = 1200, - filename: str = '', - open_browser: bool = True): +def plot_heatmaps( + heatmap: pd.Series, + agg: Union[str, Callable] = "max", + *, + ncols: int = 3, + plot_width: int = 1200, + filename: str = "", + open_browser: bool = True, +): """ Plots a grid of heatmaps, one for every pair of parameters in `heatmap`. See example in [the tutorial]. @@ -172,11 +184,12 @@ def quantile(series: Sequence, quantile: Union[None, float] = None): def compute_stats( - *, - stats: pd.Series, - data: pd.DataFrame, - trades: pd.DataFrame = None, - risk_free_rate: float = 0.) -> pd.Series: + *, + stats: pd.Series, + data: pd.DataFrame, + trades: pd.DataFrame = None, + risk_free_rate: float = 0.0, +) -> pd.Series: """ (Re-)compute strategy performance metrics. 
@@ -199,17 +212,24 @@ def compute_stats( equity = equity.copy() equity[:] = stats._equity_curve.Equity.iloc[0] for t in trades.itertuples(index=False): - equity.iloc[t.EntryBar:] += t.PnL - return _compute_stats(trades=trades, equity=equity.values, ohlc_data=data, - risk_free_rate=risk_free_rate, strategy_instance=stats._strategy) - - -def resample_apply(rule: str, - func: Optional[Callable[..., Sequence]], - series: Union[pd.Series, pd.DataFrame, _Array], - *args, - agg: Optional[Union[str, dict]] = None, - **kwargs): + equity.iloc[t.EntryBar :] += t.PnL + return _compute_stats( + trades=trades, + equity=equity.values, + ohlc_data=data, + risk_free_rate=risk_free_rate, + strategy_instance=stats._strategy, + ) + + +def resample_apply( + rule: str, + func: Optional[Callable[..., Sequence]], + series: Union[pd.Series, pd.DataFrame, _Array], + *args, + agg: Optional[Union[str, dict]] = None, + **kwargs, +): """ Apply `func` (such as an indicator) to `series`, resampled to a time frame specified by `rule`. 
When called from inside @@ -282,24 +302,26 @@ def SMA(series, n): """ if func is None: + def func(x, *_, **__): return x - assert callable(func), 'resample_apply(func=) must be callable' + + assert callable(func), "resample_apply(func=) must be callable" if not isinstance(series, (pd.Series, pd.DataFrame)): - assert isinstance(series, _Array), \ - 'resample_apply(series=) must be `pd.Series`, `pd.DataFrame`, ' \ - 'or a `Strategy.data.*` array' + assert isinstance(series, _Array), ( + "resample_apply(series=) must be `pd.Series`, `pd.DataFrame`, " + "or a `Strategy.data.*` array" + ) series = series.s if agg is None: - agg = OHLCV_AGG.get(getattr(series, 'name', ''), 'last') + agg = OHLCV_AGG.get(getattr(series, "name", ""), "last") if isinstance(series, pd.DataFrame): - agg = {column: OHLCV_AGG.get(column, 'last') - for column in series.columns} + agg = {column: OHLCV_AGG.get(column, "last") for column in series.columns} - resampled = series.resample(rule, label='right').agg(agg).dropna() - resampled.name = _as_str(series) + '[' + rule + ']' + resampled = series.resample(rule, label="right").agg(agg).dropna() + resampled.name = _as_str(series) + "[" + rule + "]" # Check first few stack frames if we are being called from # inside Strategy.init, and if so, extract Strategy.I wrapper. 
@@ -307,10 +329,11 @@ def func(x, *_, **__): while frame and level <= 3: frame = frame.f_back level += 1 - if isinstance(frame.f_locals.get('self'), Strategy): # type: ignore - strategy_I = frame.f_locals['self'].I # type: ignore + if isinstance(frame.f_locals.get("self"), Strategy): # type: ignore + strategy_I = frame.f_locals["self"].I # type: ignore break else: + def strategy_I(func, *args, **kwargs): # noqa: F811 return func(*args, **kwargs) @@ -325,8 +348,9 @@ def wrap_func(resampled, *args, **kwargs): # Resample back to data index if not isinstance(result.index, pd.DatetimeIndex): result.index = resampled.index - result = result.reindex(index=series.index.union(resampled.index), - method='ffill').reindex(series.index) + result = result.reindex(index=series.index.union(resampled.index), method="ffill").reindex( + series.index + ) return result wrap_func.__name__ = func.__name__ @@ -335,8 +359,9 @@ def wrap_func(resampled, *args, **kwargs): return array -def random_ohlc_data(example_data: pd.DataFrame, *, - frac=1., random_state: Optional[int] = None) -> Generator[pd.DataFrame, None, None]: +def random_ohlc_data( + example_data: pd.DataFrame, *, frac=1.0, random_state: Optional[int] = None +) -> Generator[pd.DataFrame, None, None]: """ OHLC data generator. The generated OHLC data has basic [descriptive statistics](https://en.wikipedia.org/wiki/Descriptive_statistics) @@ -355,19 +380,21 @@ def random_ohlc_data(example_data: pd.DataFrame, *, >>> next(ohlc_generator) # returns new random data ... 
""" + def shuffle(x): return x.sample(frac=frac, replace=frac > 1, random_state=random_state) - if len(example_data.columns.intersection({'Open', 'High', 'Low', 'Close'})) != 4: - raise ValueError("`data` must be a pandas.DataFrame with columns " - "'Open', 'High', 'Low', 'Close'") + if len(example_data.columns.intersection({"Open", "High", "Low", "Close"})) != 4: + raise ValueError( + "`data` must be a pandas.DataFrame with columns 'Open', 'High', 'Low', 'Close'" + ) while True: df = shuffle(example_data) df.index = example_data.index padding = df.Close - df.Open.shift(-1) gaps = shuffle(example_data.Open.shift(-1) - example_data.Close) deltas = (padding + gaps).shift(1).fillna(0).cumsum() - for key in ('Open', 'High', 'Low', 'Close'): + for key in ("Open", "High", "Low", "Close"): df[key] += deltas yield df @@ -394,13 +421,17 @@ def init(self): Remember to call `super().init()` and `super().next()` in your overridden methods. """ + __entry_signal = (0,) __exit_signal = (False,) - def set_signal(self, entry_size: Sequence[float], - exit_portion: Optional[Sequence[float]] = None, - *, - plot: bool = True): + def set_signal( + self, + entry_size: Sequence[float], + exit_portion: Optional[Sequence[float]] = None, + *, + plot: bool = True, + ): """ Set entry/exit signal vectors (arrays). 
@@ -417,12 +448,22 @@ def set_signal(self, entry_size: Sequence[float], """ self.__entry_signal = self.I( # type: ignore lambda: pd.Series(entry_size, dtype=float).replace(0, np.nan), - name='entry size', plot=plot, overlay=False, scatter=True, color='black') + name="entry size", + plot=plot, + overlay=False, + scatter=True, + color="black", + ) if exit_portion is not None: self.__exit_signal = self.I( # type: ignore lambda: pd.Series(exit_portion, dtype=float).replace(0, np.nan), - name='exit portion', plot=plot, overlay=False, scatter=True, color='black') + name="exit portion", + plot=plot, + overlay=False, + scatter=True, + color="black", + ) def next(self): super().next() @@ -456,7 +497,8 @@ class TrailingStrategy(Strategy): Remember to call `super().init()` and `super().next()` in your overridden methods. """ - __n_atr = 6. + + __n_atr = 6.0 __atr = None def init(self): @@ -480,7 +522,7 @@ def set_trailing_sl(self, n_atr: float = 6): """ self.__n_atr = n_atr - def set_trailing_pct(self, pct: float = .05): + def set_trailing_pct(self, pct: float = 0.05): """ Set the future trailing stop-loss as some percent (`0 < pct < 1`) below the current price (default 5% below). @@ -489,7 +531,7 @@ def set_trailing_pct(self, pct: float = .05): Stop-loss set by `set_trailing_pct` is converted to units of ATR with `mean(Close * pct / atr)` and set with `set_trailing_sl`. """ - assert 0 < pct < 1, 'Need pct= as rate, i.e. 5% == 0.05' + assert 0 < pct < 1, "Need pct= as rate, i.e. 
5% == 0.05" pct_in_atr = np.mean(self.data.Close * pct / self.__atr) # type: ignore self.set_trailing_sl(pct_in_atr) @@ -499,11 +541,13 @@ def next(self): index = len(self.data) - 1 for trade in self.trades: if trade.is_long: - trade.sl = max(trade.sl or -np.inf, - self.data.Close[index] - self.__atr[index] * self.__n_atr) + trade.sl = max( + trade.sl or -np.inf, self.data.Close[index] - self.__atr[index] * self.__n_atr + ) else: - trade.sl = min(trade.sl or np.inf, - self.data.Close[index] + self.__atr[index] * self.__n_atr) + trade.sl = min( + trade.sl or np.inf, self.data.Close[index] + self.__atr[index] * self.__n_atr + ) class FractionalBacktest(Backtest): @@ -522,38 +566,53 @@ class FractionalBacktest(Backtest): [satoshi]: https://en.wikipedia.org/wiki/Bitcoin#Units_and_divisibility """ - def __init__(self, - data, - *args, - fractional_unit=1 / 100e6, - **kwargs): - if 'satoshi' in kwargs: + + def __init__(self, data, *args, fractional_unit=1 / 100e6, **kwargs): + if "satoshi" in kwargs: warnings.warn( - 'Parameter `FractionalBacktest(..., satoshi=)` is deprecated. ' - 'Use `FractionalBacktest(..., fractional_unit=)`.', - category=DeprecationWarning, stacklevel=2) - fractional_unit = 1 / kwargs.pop('satoshi') + "Parameter `FractionalBacktest(..., satoshi=)` is deprecated. 
" + "Use `FractionalBacktest(..., fractional_unit=)`.", + category=DeprecationWarning, + stacklevel=2, + ) + fractional_unit = 1 / kwargs.pop("satoshi") self._fractional_unit = fractional_unit - self.__data: pd.DataFrame = data.copy(deep=False) # Shallow copy - for col in ('Open', 'High', 'Low', 'Close',): - self.__data[col] = self.__data[col] * self._fractional_unit - for col in ('Volume',): - self.__data[col] = self.__data[col] / self._fractional_unit + + def _fractionalize(df: pd.DataFrame) -> pd.DataFrame: + transformed = df.copy(deep=False) + for col in ("Open", "High", "Low", "Close"): + transformed[col] = transformed[col] * self._fractional_unit + for col in ("Volume",): + transformed[col] = transformed[col] / self._fractional_unit + return transformed + + if isinstance(data, pd.DataFrame): + self.__data = _fractionalize(data) + elif isinstance(data, Mapping): + self.__data = {str(symbol): _fractionalize(df) for symbol, df in data.items()} + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)): + self.__data = {str(i): _fractionalize(df) for i, df in enumerate(data)} + else: + self.__data = data with warnings.catch_warnings(record=True): - warnings.filterwarnings(action='ignore', message='frac') + warnings.filterwarnings(action="ignore", message="frac") super().__init__(data, *args, **kwargs) def run(self, **kwargs) -> pd.Series: - with patch(self, '_data', self.__data): + if isinstance(self.__data, dict): + patched_asset_data = self.__data + else: + patched_asset_data = {getattr(self, "_primary_symbol", "_"): self.__data} + with patch(self, "_data", self.__data), patch(self, "_asset_data", patched_asset_data): result = super().run(**kwargs) - trades: pd.DataFrame = result['_trades'] - trades['Size'] *= self._fractional_unit - trades[['EntryPrice', 'ExitPrice', 'TP', 'SL']] /= self._fractional_unit + trades: pd.DataFrame = result["_trades"] + trades["Size"] *= self._fractional_unit + trades[["EntryPrice", "ExitPrice", "TP", "SL"]] /= 
self._fractional_unit - indicators = result['_strategy']._indicators + indicators = result["_strategy"]._indicators for indicator in indicators: - if indicator._opts['overlay']: + if indicator._opts["overlay"]: indicator /= self._fractional_unit return result @@ -562,7 +621,7 @@ def run(self, **kwargs) -> pd.Series: # Prevent pdoc3 documenting __init__ signature of Strategy subclasses for cls in list(globals().values()): if isinstance(cls, type) and issubclass(cls, Strategy): - __pdoc__[f'{cls.__name__}.__init__'] = False + __pdoc__[f"{cls.__name__}.__init__"] = False class MultiBacktest: @@ -578,6 +637,7 @@ class MultiBacktest: stats_per_ticker: pd.DataFrame = btm.run(fast=10, slow=20) heatmap_per_ticker: pd.DataFrame = btm.optimize(...) """ + def __init__(self, df_list, strategy_cls, **kwargs): self._dfs = df_list self._strategy = strategy_cls @@ -589,16 +649,20 @@ def run(self, **kwargs): currency indexes in columns. """ from . import Pool - with Pool() as pool, \ - SharedMemoryManager() as smm: + + with Pool() as pool, SharedMemoryManager() as smm: shm = [smm.df2shm(df) for df in self._dfs] results = _tqdm( - pool.imap(self._mp_task_run, - ((df_batch, self._strategy, self._bt_kwargs, kwargs) - for df_batch in _batch(shm))), + pool.imap( + self._mp_task_run, + ( + (df_batch, self._strategy, self._bt_kwargs, kwargs) + for df_batch in _batch(shm) + ), + ), total=len(shm), desc=self.run.__qualname__, - mininterval=2 + mininterval=2, ) df = pd.DataFrame(list(chain(*results))).transpose() return df @@ -608,9 +672,10 @@ def _mp_task_run(args): data_shm, strategy, bt_kwargs, run_kwargs = args dfs, shms = zip(*(SharedMemoryManager.shm2df(i) for i in data_shm)) try: - return [stats.filter(regex='^[^_]') if stats['# Trades'] else None - for stats in (Backtest(df, strategy, **bt_kwargs).run(**run_kwargs) - for df in dfs)] + return [ + stats.filter(regex="^[^_]") if stats["# Trades"] else None + for stats in (Backtest(df, strategy, **bt_kwargs).run(**run_kwargs) for df in 
dfs) + ] finally: for shmem in chain(*shms): shmem.close() @@ -629,7 +694,8 @@ def optimize(self, **kwargs) -> pd.DataFrame: for df in _tqdm(self._dfs, desc=self.__class__.__name__, mininterval=2): bt = Backtest(df, self._strategy, **self._bt_kwargs) _best_stats, heatmap = bt.optimize( # type: ignore - return_heatmap=True, return_optimization=False, **kwargs) + return_heatmap=True, return_optimization=False, **kwargs + ) heatmaps.append(heatmap) heatmap = pd.DataFrame(dict(zip(count(), heatmaps))) return heatmap @@ -637,10 +703,17 @@ def optimize(self, **kwargs) -> pd.DataFrame: # NOTE: Don't put anything below this __all__ list -__all__ = [getattr(v, '__name__', k) - for k, v in globals().items() # export - if ((callable(v) and getattr(v, '__module__', None) == __name__ or # callables from this module - k.isupper()) and # or CONSTANTS - not getattr(v, '__name__', k).startswith('_'))] # neither marked internal +__all__ = [ + getattr(v, "__name__", k) + for k, v in globals().items() # export + if ( + ( + callable(v) + and getattr(v, "__module__", None) == __name__ # callables from this module + or k.isupper() + ) # or CONSTANTS + and not getattr(v, "__name__", k).startswith("_") + ) +] # neither marked internal # NOTE: Don't put anything below here. See above. 
diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py index 63045ce1..aecb87a0 100644 --- a/backtesting/test/_test.py +++ b/backtesting/test/_test.py @@ -19,7 +19,9 @@ from backtesting._stats import compute_drawdown_duration_peaks from backtesting._util import _Array, _as_str, _Indicator, patch, try_ from backtesting.lib import ( - FractionalBacktest, MultiBacktest, OHLCV_AGG, + FractionalBacktest, + MultiBacktest, + OHLCV_AGG, SignalStrategy, TrailingStrategy, barssince, @@ -38,8 +40,8 @@ @contextmanager def _tempfile(): - with NamedTemporaryFile(suffix='.html') as f: - if sys.platform.startswith('win'): + with NamedTemporaryFile(suffix=".html") as f: + if sys.platform.startswith("win"): f.close() yield f.name @@ -91,33 +93,33 @@ def test_run_speed(self): start = time.process_time() bt.run() end = time.process_time() - self.assertLess(end - start, .3) + self.assertLess(end - start, 0.3) def test_data_missing_columns(self): df = GOOG.copy(deep=False) - del df['Open'] + del df["Open"] with self.assertRaises(ValueError): Backtest(df, SmaCross).run() def test_data_nan_columns(self): df = GOOG.copy() - df['Open'] = np.nan + df["Open"] = np.nan with self.assertRaises(ValueError): Backtest(df, SmaCross).run() def test_data_extra_columns(self): df = GOOG.copy(deep=False) - df['P/E'] = np.arange(len(df)) - df['MCap'] = np.arange(len(df)) + df["P/E"] = np.arange(len(df)) + df["MCap"] = np.arange(len(df)) class S(Strategy): def init(self): assert len(self.data.MCap) == len(self.data.Close) - assert len(self.data['P/E']) == len(self.data.Close) + assert len(self.data["P/E"]) == len(self.data.Close) def next(self): assert len(self.data.MCap) == len(self.data.Close) - assert len(self.data['P/E']) == len(self.data.Close) + assert len(self.data["P/E"]) == len(self.data.Close) Backtest(df, S).run() @@ -133,24 +135,25 @@ def init(self): self.sma = self.I(SMA, self.data.Close, 10) self.remains_indicator = np.r_[2] * np.cumsum(self.sma * 5 + 1) * np.r_[2] - 
self.transpose_invalid = self.I(lambda: np.column_stack((self.data.Open, - self.data.Close))) + self.transpose_invalid = self.I( + lambda: np.column_stack((self.data.Open, self.data.Close)) + ) - resampled = resample_apply('W', SMA, self.data.Close, 3) - resampled_ind = resample_apply('W', SMA, self.sma, 3) + resampled = resample_apply("W", SMA, self.data.Close, 3) + resampled_ind = resample_apply("W", SMA, self.sma, 3) assert np.unique(resampled[-5:]).size == 1 assert np.unique(resampled[-6:]).size == 2 assert resampled in self._indicators, "Strategy.I not called" assert resampled_ind in self._indicators, "Strategy.I not called" assert 1 == try_(lambda: self.data.X, 1, AttributeError) - assert 1 == try_(lambda: self.data['X'], 1, KeyError) + assert 1 == try_(lambda: self.data["X"], 1, KeyError) - assert self.data.pip == .01 + assert self.data.pip == 0.01 assert float(self.data.Close) == self.data.Close[-1] - def next(self, _FEW_DAYS=pd.Timedelta('3 days')): # noqa: N803 + def next(self, _FEW_DAYS=pd.Timedelta("3 days")): # noqa: N803 assert self.equity >= 0 assert isinstance(self.sma, _Indicator) @@ -177,16 +180,16 @@ def next(self, _FEW_DAYS=pd.Timedelta('3 days')): # noqa: N803 if not order.is_contingent: order.cancel() price = self.data.Close[-1] - sl, tp = 1.05 * price, .9 * price + sl, tp = 1.05 * price, 0.9 * price n_orders = len(self.orders) - self.sell(size=.21, limit=price, stop=price, sl=sl, tp=tp) + self.sell(size=0.21, limit=price, stop=price, sl=sl, tp=tp) assert len(self.orders) == n_orders + 1 order = self.orders[-1] assert order.limit == price assert order.stop == price - assert order.size == -.21 + assert order.size == -0.21 assert order.sl == sl assert order.tp == tp assert not order.is_contingent @@ -215,20 +218,21 @@ def next(self, _FEW_DAYS=pd.Timedelta('3 days')): # noqa: N803 assert trade.sl assert trade.tp # Close multiple times - self.position.close(.5) - self.position.close(.5) - self.position.close(.5) + self.position.close(0.5) + 
self.position.close(0.5) + self.position.close(0.5) self.position.close() self.position.close() bt = Backtest(GOOG, Assertive) with self.assertWarns(UserWarning): stats = bt.run() - self.assertEqual(stats['# Trades'], 131) + self.assertEqual(stats["# Trades"], 131) def test_broker_params(self): - bt = Backtest(GOOG.iloc[:100], SmaCross, - cash=1000, spread=.01, margin=.1, trade_on_close=True) + bt = Backtest( + GOOG.iloc[:100], SmaCross, cash=1000, spread=0.01, margin=0.1, trade_on_close=True + ) bt.run() def test_spread_commission(self): @@ -243,24 +247,27 @@ def next(self): self.position.close() self.next = lambda: None # Done - SPREAD = .01 - COMMISSION = .01 + SPREAD = 0.01 + COMMISSION = 0.01 CASH = 10_000 ORDER_BAR = 2 stats = Backtest(SHORT_DATA, S, cash=CASH, spread=SPREAD, commission=COMMISSION).run() - trade_open_price = SHORT_DATA['Open'].iloc[ORDER_BAR] - self.assertEqual(stats['_trades']['EntryPrice'].iloc[0], trade_open_price * (1 + SPREAD)) - self.assertEqual(stats['_equity_curve']['Equity'].iloc[2:4].round(2).tolist(), - [9685.31, 9749.33]) + trade_open_price = SHORT_DATA["Open"].iloc[ORDER_BAR] + self.assertEqual(stats["_trades"]["EntryPrice"].iloc[0], trade_open_price * (1 + SPREAD)) + self.assertEqual( + stats["_equity_curve"]["Equity"].iloc[2:4].round(2).tolist(), [9685.31, 9749.33] + ) stats = Backtest(SHORT_DATA, S, cash=CASH, commission=(100, COMMISSION)).run() - self.assertEqual(stats['_equity_curve']['Equity'].iloc[2:4].round(2).tolist(), - [9784.50, 9718.69]) + self.assertEqual( + stats["_equity_curve"]["Equity"].iloc[2:4].round(2).tolist(), [9784.50, 9718.69] + ) commission_func = lambda size, price: size * price * COMMISSION # noqa: E731 stats = Backtest(SHORT_DATA, S, cash=CASH, commission=commission_func).run() - self.assertEqual(stats['_equity_curve']['Equity'].iloc[2:4].round(2).tolist(), - [9781.28, 9846.04]) + self.assertEqual( + stats["_equity_curve"]["Equity"].iloc[2:4].round(2).tolist(), [9781.28, 9846.04] + ) def 
test_commissions(self): class S(_S): @@ -268,20 +275,24 @@ def next(self): if len(self.data) == 2: self.buy(size=SIZE, tp=3) - FIXED_COMMISSION, COMMISSION = 10, .01 + FIXED_COMMISSION, COMMISSION = 10, 0.01 CASH, SIZE, PRICE_ENTRY, PRICE_EXIT = 5000, 100, 1, 4 arr = np.r_[1, PRICE_ENTRY, 1, 2, PRICE_EXIT, 1, 2] - df = pd.DataFrame({'Open': arr, 'High': arr, 'Low': arr, 'Close': arr}) - with self.assertWarnsRegex(UserWarning, 'index is not datetime'): + df = pd.DataFrame({"Open": arr, "High": arr, "Low": arr, "Close": arr}) + with self.assertWarnsRegex(UserWarning, "index is not datetime"): stats = Backtest(df, S, cash=CASH, commission=(FIXED_COMMISSION, COMMISSION)).run() EXPECTED_PAID_COMMISSION = ( - FIXED_COMMISSION + COMMISSION * SIZE * PRICE_ENTRY + - FIXED_COMMISSION + COMMISSION * SIZE * PRICE_EXIT) - self.assertEqual(stats['Commissions [$]'], EXPECTED_PAID_COMMISSION) - self.assertEqual(stats._trades['Commission'][0], EXPECTED_PAID_COMMISSION) + FIXED_COMMISSION + + COMMISSION * SIZE * PRICE_ENTRY + + FIXED_COMMISSION + + COMMISSION * SIZE * PRICE_EXIT + ) + self.assertEqual(stats["Commissions [$]"], EXPECTED_PAID_COMMISSION) + self.assertEqual(stats._trades["Commission"][0], EXPECTED_PAID_COMMISSION) self.assertEqual( - stats['Equity Final [$]'], - CASH + (PRICE_EXIT - PRICE_ENTRY) * SIZE - EXPECTED_PAID_COMMISSION) + stats["Equity Final [$]"], + CASH + (PRICE_EXIT - PRICE_ENTRY) * SIZE - EXPECTED_PAID_COMMISSION, + ) def test_dont_overwrite_data(self): df = EURUSD.copy() @@ -300,7 +311,7 @@ class MyStrategy(Strategy): def test_strategy_str(self): bt = Backtest(GOOG.iloc[:100], SmaCross) self.assertEqual(str(bt.run()._strategy), SmaCross.__name__) - self.assertEqual(str(bt.run(fast=11)._strategy), SmaCross.__name__ + '(fast=11)') + self.assertEqual(str(bt.run(fast=11)._strategy), SmaCross.__name__ + "(fast=11)") def test_compute_drawdown(self): dd = pd.Series([0, 1, 7, 0, 4, 0, 0]) @@ -310,68 +321,90 @@ def test_compute_drawdown(self): def 
test_compute_stats(self): stats = Backtest(GOOG, SmaCross, finalize_trades=True).run() - expected = pd.Series({ + expected = pd.Series( + { # NOTE: These values are also used on the website! # noqa: E126 - '# Trades': 66, - 'Avg. Drawdown Duration': pd.Timedelta('41 days 00:00:00'), - 'Avg. Drawdown [%]': -5.925851581948801, - 'Avg. Trade Duration': pd.Timedelta('46 days 00:00:00'), - 'Avg. Trade [%]': 2.531715975158555, - 'Best Trade [%]': 53.59595229490424, - 'Buy & Hold Return [%]': 522.0601851851852, - 'Calmar Ratio': 0.4414380935608377, - 'Duration': pd.Timedelta('3116 days 00:00:00'), - 'End': pd.Timestamp('2013-03-01 00:00:00'), - 'Equity Final [$]': 51422.98999999996, - 'Equity Peak [$]': 75787.44, - 'Expectancy [%]': 3.2748078066748834, - 'Exposure Time [%]': 96.74115456238361, - 'Max. Drawdown Duration': pd.Timedelta('584 days 00:00:00'), - 'Max. Drawdown [%]': -47.98012705007589, - 'Max. Trade Duration': pd.Timedelta('183 days 00:00:00'), - 'Profit Factor': 2.167945974262033, - 'Return (Ann.) [%]': 21.180255813792282, - 'Return [%]': 414.2298999999996, - 'Volatility (Ann.) [%]': 36.49390889140787, - 'CAGR [%]': 14.159843619607383, - 'SQN': 1.0766187356697705, - 'Kelly Criterion': 0.1518705127029717, - 'Sharpe Ratio': 0.5803778344714113, - 'Sortino Ratio': 1.0847880675854096, - 'Start': pd.Timestamp('2004-08-19 00:00:00'), - 'Win Rate [%]': 46.96969696969697, - 'Worst Trade [%]': -18.39887353835481, - 'Alpha [%]': 394.37391142027462, - 'Beta': 0.03803390709192, - }) + "# Trades": 66, + "Avg. Drawdown Duration": pd.Timedelta("41 days 00:00:00"), + "Avg. Drawdown [%]": -5.925851581948801, + "Avg. Trade Duration": pd.Timedelta("46 days 00:00:00"), + "Avg. 
Trade [%]": 2.531715975158555, + "Best Trade [%]": 53.59595229490424, + "Buy & Hold Return [%]": 522.0601851851852, + "Calmar Ratio": 0.4414380935608377, + "Duration": pd.Timedelta("3116 days 00:00:00"), + "End": pd.Timestamp("2013-03-01 00:00:00"), + "Equity Final [$]": 51422.98999999996, + "Equity Peak [$]": 75787.44, + "Expectancy [%]": 3.2748078066748834, + "Exposure Time [%]": 96.74115456238361, + "Max. Drawdown Duration": pd.Timedelta("584 days 00:00:00"), + "Max. Drawdown [%]": -47.98012705007589, + "Max. Trade Duration": pd.Timedelta("183 days 00:00:00"), + "Profit Factor": 2.167945974262033, + "Return (Ann.) [%]": 21.180255813792282, + "Return [%]": 414.2298999999996, + "Volatility (Ann.) [%]": 36.49390889140787, + "CAGR [%]": 14.159843619607383, + "SQN": 1.0766187356697705, + "Kelly Criterion": 0.1518705127029717, + "Sharpe Ratio": 0.5803778344714113, + "Sortino Ratio": 1.0847880675854096, + "Start": pd.Timestamp("2004-08-19 00:00:00"), + "Win Rate [%]": 46.96969696969697, + "Worst Trade [%]": -18.39887353835481, + "Alpha [%]": 394.37391142027462, + "Beta": 0.03803390709192, + } + ) def almost_equal(a, b): try: - return np.isclose(a, b, rtol=1.e-8) + return np.isclose(a, b, rtol=1.0e-8) except TypeError: return a == b - diff = {key: print(key) or value # noqa: T201 - for key, value in stats.filter(regex='^[^_]').items() - if not almost_equal(value, expected[key])} + diff = { + key: print(key) or value # noqa: T201 + for key, value in stats.filter(regex="^[^_]").items() + if not almost_equal(value, expected[key]) + } self.assertDictEqual(diff, {}) self.assertSequenceEqual( - sorted(stats['_equity_curve'].columns), - sorted(['Equity', 'DrawdownPct', 'DrawdownDuration'])) + sorted(stats["_equity_curve"].columns), + sorted(["Equity", "DrawdownPct", "DrawdownDuration"]), + ) - self.assertEqual(len(stats['_trades']), 66) + self.assertEqual(len(stats["_trades"]), 66) indicator_columns = [ - f'{entry}_SMA(C,{n})' - for entry in ('Entry', 'Exit') - for n in 
(SmaCross.fast, SmaCross.slow)] + f"{entry}_SMA(C,{n})" + for entry in ("Entry", "Exit") + for n in (SmaCross.fast, SmaCross.slow) + ] self.assertSequenceEqual( - sorted(stats['_trades'].columns), - sorted(['Size', 'EntryBar', 'ExitBar', 'EntryPrice', 'ExitPrice', - 'SL', 'TP', 'PnL', 'ReturnPct', 'EntryTime', 'ExitTime', - 'Duration', 'Tag', 'Commission', - *indicator_columns])) + sorted(stats["_trades"].columns), + sorted( + [ + "Size", + "EntryBar", + "ExitBar", + "EntryPrice", + "ExitPrice", + "SL", + "TP", + "PnL", + "ReturnPct", + "EntryTime", + "ExitTime", + "Duration", + "Tag", + "Commission", + *indicator_columns, + ] + ), + ) def test_compute_stats_bordercase(self): @@ -395,16 +428,13 @@ class NoTrade(_S): def next(self): pass - for strategy in (SmaCross, - SingleTrade, - SinglePosition, - NoTrade): + for strategy in (SmaCross, SingleTrade, SinglePosition, NoTrade): with self.subTest(strategy=strategy.__name__): stats = Backtest(GOOG.iloc[:100], strategy).run() - self.assertFalse(np.isnan(stats['Equity Final [$]'])) - self.assertFalse(stats['_equity_curve']['Equity'].isnull().any()) - self.assertEqual(stats['_strategy'].__class__, strategy) + self.assertFalse(np.isnan(stats["Equity Final [$]"])) + self.assertFalse(stats["_equity_curve"]["Equity"].isnull().any()) + self.assertEqual(stats["_strategy"].__class__, strategy) def test_trade_enter_hit_sl_on_same_day(self): the_day = pd.Timestamp("2012-10-17 00:00:00") @@ -443,9 +473,9 @@ def next(self): if not self.position and crossover(self.sma1, self.sma2): self.buy(size=10) if self.position and crossover(self.sma2, self.sma1): - self.position.close(portion=.5) + self.position.close(portion=0.5) - bt = Backtest(GOOG, SmaCross, spread=.002) + bt = Backtest(GOOG, SmaCross, spread=0.002) bt.run() def test_close_orders_from_last_strategy_iteration(self): @@ -456,7 +486,7 @@ def next(self): elif len(self.data) == len(SHORT_DATA): self.position.close() - with self.assertWarnsRegex(UserWarning, 'finalize_trades'): + 
with self.assertWarnsRegex(UserWarning, "finalize_trades"): self.assertTrue(Backtest(SHORT_DATA, S, finalize_trades=False).run()._trades.empty) self.assertFalse(Backtest(SHORT_DATA, S, finalize_trades=True).run()._trades.empty) @@ -465,7 +495,56 @@ class S(_S): def next(self): self.buy(tp=self.data.Close * 1.01) - self.assertRaises(ValueError, Backtest(SHORT_DATA, S, spread=.02).run) + self.assertRaises(ValueError, Backtest(SHORT_DATA, S, spread=0.02).run) + + def test_multi_asset_run_and_symbols(self): + n = 220 + goog = GOOG.iloc[:n].copy() + half = GOOG.iloc[:n].copy() + half[["Open", "High", "Low", "Close"]] = half[["Open", "High", "Low", "Close"]] * 0.5 + + class S(_S): + def next(self): + i = len(self.data.index) + if i == 20: + self.buy(symbol="GOOG", size=1) + self.buy(symbol="HALF", size=2) + if i == 30: + self.position["GOOG"].close() + if i == 40: + self.position["HALF"].close() + + stats = Backtest({"GOOG": goog, "HALF": half}, S, cash=1_000_000).run() + self.assertEqual(len(stats._trades), 2) + self.assertIn("Symbol", stats._trades.columns) + self.assertEqual(stats._trades["Symbol"].tolist(), ["GOOG", "HALF"]) + + def test_multi_asset_position_accessor(self): + n = 120 + goog = GOOG.iloc[:n].copy() + half = GOOG.iloc[:n].copy() + half[["Open", "High", "Low", "Close"]] = half[["Open", "High", "Low", "Close"]] * 0.5 + + class S(_S): + def next(self): + i = len(self.data.index) + if i == 10: + self.buy(symbol="GOOG", size=1) + self.buy(symbol="HALF", size=1) + if i == 11: + assert self.position["GOOG"].size == 1 + assert self.position["HALF"].size == 1 + assert self.position.size == 2 + self.position.close() + + Backtest({"GOOG": goog, "HALF": half}, S, cash=1_000_000).run() + + def test_multi_asset_misaligned_index_raises(self): + n = 40 + goog = GOOG.iloc[:n].copy() + shifted = GOOG.iloc[1 : n + 1].copy() + with self.assertRaisesRegex(ValueError, "identical, aligned indexes"): + Backtest({"GOOG": goog, "SHIFT": shifted}, SmaCross) class 
TestStrategy(TestCase): @@ -531,7 +610,7 @@ def coroutine(self): assert self.trades self.trades[-1].close(1) - self.trades[-1].close(.1) + self.trades[-1].close(0.1) yield self._Backtest(coroutine).run() @@ -553,7 +632,7 @@ def coroutine(self): yield stats = self._Backtest(coroutine).run() - self.assertListEqual(stats._trades.filter(like='Price').stack().tolist(), [112, 107]) + self.assertListEqual(stats._trades.filter(like="Price").stack().tolist(), [112, 107]) def test_autoclose_trades_on_finish(self): def coroutine(self): @@ -565,7 +644,7 @@ def coroutine(self): def test_order_tag(self): def coroutine(self): yield self.buy(size=2, tag=1) - yield self.sell(size=1, tag='s') + yield self.sell(size=1, tag="s") yield self.sell(size=1) yield self.buy(tag=2) @@ -578,11 +657,11 @@ def coroutine(self): class TestOptimize(TestCase): def test_optimize(self): bt = Backtest(GOOG.iloc[:100], SmaCross) - OPT_PARAMS = {'fast': range(2, 5, 2), 'slow': [2, 5, 7, 9]} + OPT_PARAMS = {"fast": range(2, 5, 2), "slow": [2, 5, 7, 9]} self.assertRaises(ValueError, bt.optimize) - self.assertRaises(ValueError, bt.optimize, maximize='missing key', **OPT_PARAMS) - self.assertRaises(ValueError, bt.optimize, maximize='missing key', **OPT_PARAMS) + self.assertRaises(ValueError, bt.optimize, maximize="missing key", **OPT_PARAMS) + self.assertRaises(ValueError, bt.optimize, maximize="missing key", **OPT_PARAMS) self.assertRaises(TypeError, bt.optimize, maximize=15, **OPT_PARAMS) self.assertRaises(TypeError, bt.optimize, constraint=15, **OPT_PARAMS) self.assertRaises(ValueError, bt.optimize, constraint=lambda d: False, **OPT_PARAMS) @@ -591,13 +670,16 @@ def test_optimize(self): res = bt.optimize(**OPT_PARAMS) self.assertIsInstance(res, pd.Series) - default_maximize = inspect.signature(Backtest.optimize).parameters['maximize'].default + default_maximize = inspect.signature(Backtest.optimize).parameters["maximize"].default res2 = bt.optimize(**OPT_PARAMS, maximize=lambda s: s[default_maximize]) - 
self.assertDictEqual(res.filter(regex='^[^_]').fillna(-1).to_dict(), - res2.filter(regex='^[^_]').fillna(-1).to_dict()) - - res3, heatmap = bt.optimize(**OPT_PARAMS, return_heatmap=True, - constraint=lambda d: d.slow > 2 * d.fast) + self.assertDictEqual( + res.filter(regex="^[^_]").fillna(-1).to_dict(), + res2.filter(regex="^[^_]").fillna(-1).to_dict(), + ) + + res3, heatmap = bt.optimize( + **OPT_PARAMS, return_heatmap=True, constraint=lambda d: d.slow > 2 * d.fast + ) self.assertIsInstance(heatmap, pd.Series) self.assertEqual(len(heatmap), 4) self.assertEqual(heatmap.name, default_maximize) @@ -608,13 +690,15 @@ def test_optimize(self): def test_method_sambo(self): bt = Backtest(GOOG.iloc[:100], SmaCross, finalize_trades=True) res, heatmap, sambo_results = bt.optimize( - fast=range(2, 20), slow=np.arange(2, 20, dtype=object), + fast=range(2, 20), + slow=np.arange(2, 20, dtype=object), constraint=lambda p: p.fast < p.slow, max_tries=30, - method='sambo', + method="sambo", return_optimization=True, return_heatmap=True, - random_state=2) + random_state=2, + ) self.assertIsInstance(res, pd.Series) self.assertIsInstance(heatmap, pd.Series) self.assertGreater(heatmap.max(), 1.1) @@ -624,19 +708,21 @@ def test_method_sambo(self): def test_max_tries(self): bt = Backtest(GOOG.iloc[:100], SmaCross) - OPT_PARAMS = {'fast': range(2, 10, 2), 'slow': [2, 5, 7, 9]} - for method, max_tries, random_state in (('grid', 5, 0), - ('grid', .3, 0), - ('sambo', 6, 0), - ('sambo', .42, 0)): - with self.subTest(method=method, - max_tries=max_tries, - random_state=random_state): - _, heatmap = bt.optimize(max_tries=max_tries, - method=method, - random_state=random_state, - return_heatmap=True, - **OPT_PARAMS) + OPT_PARAMS = {"fast": range(2, 10, 2), "slow": [2, 5, 7, 9]} + for method, max_tries, random_state in ( + ("grid", 5, 0), + ("grid", 0.3, 0), + ("sambo", 6, 0), + ("sambo", 0.42, 0), + ): + with self.subTest(method=method, max_tries=max_tries, random_state=random_state): + _, 
heatmap = bt.optimize( + max_tries=max_tries, + method=method, + random_state=random_state, + return_heatmap=True, + **OPT_PARAMS, + ) self.assertEqual(len(heatmap), 6) def test_optimize_invalid_param(self): @@ -655,8 +741,8 @@ def test_optimize_speed(self): bt.optimize(fast=range(2, 20, 2), slow=range(10, 40, 2)) end = time.process_time() print(end - start) - handicap = 5 if 'win' in sys.platform else .1 - self.assertLess(end - start, .3 + handicap) + handicap = 5 if "win" in sys.platform else 0.1 + self.assertLess(end - start, 0.3 + handicap) class TestPlot(TestCase): @@ -668,25 +754,27 @@ def test_file_size(self): bt = Backtest(GOOG, SmaCross) bt.run() with _tempfile() as f: - bt.plot(filename=f[:-len('.html')], open_browser=False) + bt.plot(filename=f[: -len(".html")], open_browser=False) self.assertLess(os.path.getsize(f), 500000) def test_params(self): bt = Backtest(GOOG.iloc[:100], SmaCross) bt.run() with _tempfile() as f: - for p in dict(plot_volume=False, # noqa: C408 - plot_equity=False, - plot_return=True, - plot_pl=False, - plot_drawdown=True, - plot_trades=False, - superimpose=False, - resample='1W', - smooth_equity=False, - relative_equity=False, - reverse_indicators=True, - show_legend=False).items(): + for p in dict( + plot_volume=False, # noqa: C408 + plot_equity=False, + plot_return=True, + plot_pl=False, + plot_drawdown=True, + plot_trades=False, + superimpose=False, + resample="1W", + smooth_equity=False, + relative_equity=False, + reverse_indicators=True, + show_legend=False, + ).items(): with self.subTest(param=p[0]): bt.plot(**dict([p]), filename=f, open_browser=False) @@ -700,7 +788,7 @@ def test_hide_legend(self): def test_resolutions(self): with _tempfile() as f: - for rule in 'ms s min h D W ME'.split(): + for rule in "ms s min h D W ME".split(): with self.subTest(rule=rule): df = EURUSD.iloc[:2].resample(rule).agg(OHLCV_AGG).dropna().iloc[:1100] bt = Backtest(df, SmaCross) @@ -731,7 +819,7 @@ def init(self): def ok(x): return x - self.a 
= self.I(SMA, self.data.Open, 5, overlay=False, name='ok') + self.a = self.I(SMA, self.data.Open, 5, overlay=False, name="ok") self.b = self.I(ok, np.random.random(len(self.data.Open))) bt = Backtest(GOOG, Strategy) @@ -745,26 +833,33 @@ def test_wellknown(self): class S(_S): def next(self): date = self.data.index[-1] - if date == pd.Timestamp('Thu 19 Oct 2006'): + if date == pd.Timestamp("Thu 19 Oct 2006"): self.buy(stop=484, limit=466, size=100) - elif date == pd.Timestamp('Thu 30 Oct 2007'): + elif date == pd.Timestamp("Thu 30 Oct 2007"): self.position.close() - elif date == pd.Timestamp('Tue 11 Nov 2008'): - self.sell(stop=self.data.Low, - limit=324.90, # High from 14 Nov - size=200) - - bt = Backtest(GOOG, S, margin=.1) + elif date == pd.Timestamp("Tue 11 Nov 2008"): + self.sell( + stop=self.data.Low, + limit=324.90, # High from 14 Nov + size=200, + ) + + bt = Backtest(GOOG, S, margin=0.1) stats = bt.run() - trades = stats['_trades'] + trades = stats["_trades"] - self.assertAlmostEqual(stats['Equity Peak [$]'], 46961) - self.assertEqual(stats['Equity Final [$]'], 0) + self.assertAlmostEqual(stats["Equity Peak [$]"], 46961) + self.assertEqual(stats["Equity Final [$]"], 0) self.assertEqual(len(trades), 2) - assert trades[['EntryTime', 'ExitTime']].equals( - pd.DataFrame({'EntryTime': pd.to_datetime(['2006-11-01', '2008-11-14']), - 'ExitTime': pd.to_datetime(['2007-10-31', '2009-09-21'])})) - assert trades['PnL'].round().equals(pd.Series([23469., -34420.])) + assert trades[["EntryTime", "ExitTime"]].equals( + pd.DataFrame( + { + "EntryTime": pd.to_datetime(["2006-11-01", "2008-11-14"]), + "ExitTime": pd.to_datetime(["2007-10-31", "2009-09-21"]), + } + ) + ) + assert trades["PnL"].round().equals(pd.Series([23469.0, -34420.0])) with _tempfile() as f: bt.plot(filename=f, plot_drawdown=True, smooth_equity=False) @@ -774,15 +869,18 @@ def next(self): def test_resample(self): class S(SmaCross): def init(self): - self.I(lambda: ['x'] * len(self.data)) # categorical 
indicator, GH-309 + self.I(lambda: ["x"] * len(self.data)) # categorical indicator, GH-309 super().init() bt = Backtest(GOOG, S) bt.run() import backtesting._plotting - with _tempfile() as f, \ - patch(backtesting._plotting, '_MAX_CANDLES', 10), \ - self.assertWarns(UserWarning): + + with ( + _tempfile() as f, + patch(backtesting._plotting, "_MAX_CANDLES", 10), + self.assertWarns(UserWarning), + ): bt.plot(filename=f, resample=True) # Give browser time to open before tempfile is removed time.sleep(1) @@ -796,14 +894,15 @@ def _SMA(): return SMA(self.data.Close, 5), SMA(self.data.Close, 10) test_self.assertRaises(TypeError, self.I, _SMA, name=42) - test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", )) + test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One",)) test_self.assertRaises( - ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three")) + ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three") + ) for overlay in (True, False): self.I(SMA, self.data.Close, 5, overlay=overlay) self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay) - self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay) + self.I(SMA, self.data.Close, 5, name=("My SMA",), overlay=overlay) self.I(_SMA, overlay=overlay) self.I(_SMA, name="My SMA", overlay=overlay) self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay) @@ -814,16 +913,21 @@ def next(self): bt = Backtest(GOOG, S) bt.run() with _tempfile() as f: - bt.plot(filename=f, - plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False, - open_browser=False) + bt.plot( + filename=f, + plot_drawdown=False, + plot_equity=False, + plot_pl=False, + plot_volume=False, + open_browser=False, + ) def test_indicator_color(self): class S(Strategy): def init(self): - a = self.I(SMA, self.data.Close, 5, overlay=True, color='red') - b = self.I(SMA, self.data.Close, 10, overlay=False, color='blue') - self.I(lambda: (a, b), overlay=False, color=('green', 'orange')) + a = 
self.I(SMA, self.data.Close, 5, overlay=True, color="red") + b = self.I(SMA, self.data.Close, 10, overlay=False, color="blue") + self.I(lambda: (a, b), overlay=False, color=("green", "orange")) def next(self): pass @@ -831,9 +935,14 @@ def next(self): bt = Backtest(GOOG, S) bt.run() with _tempfile() as f: - bt.plot(filename=f, - plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False, - open_browser=False) + bt.plot( + filename=f, + plot_drawdown=False, + plot_equity=False, + plot_pl=False, + plot_volume=False, + open_browser=False, + ) def test_indicator_scatter(self): class S(Strategy): @@ -847,9 +956,14 @@ def next(self): bt = Backtest(GOOG, S) bt.run() with _tempfile() as f: - bt.plot(filename=f, - plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False, - open_browser=False) + bt.plot( + filename=f, + plot_drawdown=False, + plot_equity=False, + plot_pl=False, + plot_volume=False, + open_browser=False, + ) class TestLib(TestCase): @@ -865,44 +979,41 @@ def test_cross(self): def test_crossover(self): self.assertTrue(crossover([0, 1], [1, 0])) - self.assertTrue(crossover([0, 1], .5)) - self.assertTrue(crossover([0, 1], pd.Series([.5, .5], index=[5, 6]))) + self.assertTrue(crossover([0, 1], 0.5)) + self.assertTrue(crossover([0, 1], pd.Series([0.5, 0.5], index=[5, 6]))) self.assertFalse(crossover([1, 0], [1, 0])) self.assertFalse(crossover([0], [1])) def test_quantile(self): - self.assertEqual(quantile(np.r_[1, 3, 2], .5), 2) - self.assertEqual(quantile(np.r_[1, 3, 2]), .5) + self.assertEqual(quantile(np.r_[1, 3, 2], 0.5), 2) + self.assertEqual(quantile(np.r_[1, 3, 2]), 0.5) def test_resample_apply(self): - res = resample_apply('D', SMA, EURUSD.Close, 10) - self.assertEqual(res.name, 'C[D]') - self.assertEqual(res.count() / res.size, .9634) - np.testing.assert_almost_equal(res.iloc[-48:].unique().tolist(), - [1.242643, 1.242381, 1.242275], - decimal=6) + res = resample_apply("D", SMA, EURUSD.Close, 10) + self.assertEqual(res.name, 
"C[D]") + self.assertEqual(res.count() / res.size, 0.9634) + np.testing.assert_almost_equal( + res.iloc[-48:].unique().tolist(), [1.242643, 1.242381, 1.242275], decimal=6 + ) def resets_index(*args): return pd.Series(SMA(*args).values) - res2 = resample_apply('D', resets_index, EURUSD.Close, 10) + res2 = resample_apply("D", resets_index, EURUSD.Close, 10) self.assertTrue((res.dropna() == res2.dropna()).all()) self.assertTrue((res.index == res2.index).all()) - res3 = resample_apply('D', None, EURUSD) - self.assertIn('Volume', res3) + res3 = resample_apply("D", None, EURUSD) + self.assertIn("Volume", res3) - res3 = resample_apply('D', lambda df: (df.Close, df.Close), EURUSD) + res3 = resample_apply("D", lambda df: (df.Close, df.Close), EURUSD) self.assertIsInstance(res3, pd.DataFrame) def test_plot_heatmaps(self): bt = Backtest(GOOG, SmaCross) - stats, heatmap = bt.optimize(fast=range(2, 7, 2), - slow=range(7, 15, 2), - return_heatmap=True) + stats, heatmap = bt.optimize(fast=range(2, 7, 2), slow=range(7, 15, 2), return_heatmap=True) with _tempfile() as f: - for agg in ('mean', - lambda x: np.percentile(x, 75)): + for agg in ("mean", lambda x: np.percentile(x, 75)): plot_heatmaps(heatmap, agg, filename=f, open_browser=False) # Preview @@ -919,12 +1030,12 @@ def test_random_ohlc_data(self): def test_compute_stats(self): stats = Backtest(GOOG, SmaCross).run() only_long_trades = stats._trades[stats._trades.Size > 0] - long_stats = compute_stats(stats=stats, trades=only_long_trades, - data=GOOG, risk_free_rate=.02) - self.assertNotEqual(list(stats._equity_curve.Equity), - list(long_stats._equity_curve.Equity)) - self.assertNotEqual(stats['Sharpe Ratio'], long_stats['Sharpe Ratio']) - self.assertEqual(long_stats['# Trades'], len(only_long_trades)) + long_stats = compute_stats( + stats=stats, trades=only_long_trades, data=GOOG, risk_free_rate=0.02 + ) + self.assertNotEqual(list(stats._equity_curve.Equity), list(long_stats._equity_curve.Equity)) + 
self.assertNotEqual(stats["Sharpe Ratio"], long_stats["Sharpe Ratio"]) + self.assertEqual(long_stats["# Trades"], len(only_long_trades)) self.assertEqual(stats._strategy, long_stats._strategy) assert_frame_equal(long_stats._trades, only_long_trades) @@ -932,18 +1043,17 @@ def test_SignalStrategy(self): class S(SignalStrategy): def init(self): sma = self.data.Close.s.rolling(10).mean() - self.set_signal(self.data.Close > sma, - self.data.Close < sma) + self.set_signal(self.data.Close > sma, self.data.Close < sma) stats = Backtest(GOOG, S).run() - self.assertIn(stats['# Trades'], (1179, 1180)) # varies on different archs? + self.assertIn(stats["# Trades"], (1179, 1180)) # varies on different archs? def test_TrailingStrategy(self): class S(TrailingStrategy): def init(self): super().init() self.set_atr_periods(40) - self.set_trailing_pct(.1) + self.set_trailing_pct(0.1) self.set_trailing_sl(3) self.sma = self.I(lambda: self.data.Close.s.rolling(10).mean()) @@ -953,24 +1063,27 @@ def next(self): self.buy() stats = Backtest(GOOG, S).run() - self.assertEqual(stats['# Trades'], 56) + self.assertEqual(stats["# Trades"], 56) def test_FractionalBacktest(self): - ubtc_bt = FractionalBacktest(BTCUSD['2015':], SmaCross, fractional_unit=1 / 1e6, cash=100) + ubtc_bt = FractionalBacktest(BTCUSD["2015":], SmaCross, fractional_unit=1 / 1e6, cash=100) stats = ubtc_bt.run(fast=2, slow=3) - self.assertEqual(stats['# Trades'], 41) - trades = stats['_trades'] + self.assertEqual(stats["# Trades"], 41) + trades = stats["_trades"] self.assertEqual(len(trades), 41) trade = trades.iloc[0] - self.assertAlmostEqual(trade['EntryPrice'], 236.69) - self.assertAlmostEqual(stats['_strategy']._indicators[0][trade['EntryBar']], 234.14) + self.assertAlmostEqual(trade["EntryPrice"], 236.69) + self.assertAlmostEqual(stats["_strategy"]._indicators[0][trade["EntryBar"]], 234.14) def test_MultiBacktest(self): import backtesting - assert callable(getattr(backtesting, 'Pool', None)), backtesting.__dict__ + + 
assert callable(getattr(backtesting, "Pool", None)), backtesting.__dict__ for start_method in mp.get_all_start_methods(): - with self.subTest(start_method=start_method), \ - patch(backtesting, 'Pool', mp.get_context(start_method).Pool): + with ( + self.subTest(start_method=start_method), + patch(backtesting, "Pool", mp.get_context(start_method).Pool), + ): start_time = time.monotonic() btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000) res = btm.run(fast=2) @@ -992,24 +1105,25 @@ class Class: def __call__(self): pass - self.assertEqual(_as_str('4'), '4') - self.assertEqual(_as_str(4), '4') - self.assertEqual(_as_str(_Indicator([1, 2], name='x')), 'x') - self.assertEqual(_as_str(func), 'func') - self.assertEqual(_as_str(Class), 'Class') - self.assertEqual(_as_str(Class()), 'Class') - self.assertEqual(_as_str(pd.Series([1, 2], name='x')), 'x') - self.assertEqual(_as_str(pd.DataFrame()), 'df') - self.assertEqual(_as_str(lambda x: x), 'λ') - for s in ('Open', 'High', 'Low', 'Close', 'Volume'): + self.assertEqual(_as_str("4"), "4") + self.assertEqual(_as_str(4), "4") + self.assertEqual(_as_str(_Indicator([1, 2], name="x")), "x") + self.assertEqual(_as_str(func), "func") + self.assertEqual(_as_str(Class), "Class") + self.assertEqual(_as_str(Class()), "Class") + self.assertEqual(_as_str(pd.Series([1, 2], name="x")), "x") + self.assertEqual(_as_str(pd.DataFrame()), "df") + self.assertEqual(_as_str(lambda x: x), "λ") + for s in ("Open", "High", "Low", "Close", "Volume"): self.assertEqual(_as_str(_Array([1], name=s)), s[0]) def test_patch(self): class Object: pass + o = Object() o.attr = False - with patch(o, 'attr', True): + with patch(o, "attr", True): self.assertTrue(o.attr) self.assertFalse(o.attr) @@ -1018,14 +1132,14 @@ class S(Strategy): def init(self): close, index = self.data.Close, self.data.index assert close.s.equals(pd.Series(close, index=index)) - assert self.data.df['Close'].equals(pd.Series(close, index=index)) - self.data.df['new_key'] = 2 * 
close + assert self.data.df["Close"].equals(pd.Series(close, index=index)) + self.data.df["new_key"] = 2 * close def next(self): close, index = self.data.Close, self.data.index assert close.s.equals(pd.Series(close, index=index)) - assert self.data.df['Close'].equals(pd.Series(close, index=index)) - assert self.data.df['new_key'].equals(pd.Series(self.data.new_key, index=index)) + assert self.data.df["Close"].equals(pd.Series(close, index=index)) + assert self.data.df["new_key"].equals(pd.Series(self.data.new_key, index=index)) Backtest(GOOG.iloc[:20], S).run() @@ -1033,21 +1147,23 @@ def test_indicators_picklable(self): bt = Backtest(SHORT_DATA, SmaCross) with ProcessPoolExecutor() as executor: stats = executor.submit(Backtest.run, bt).result() - assert stats._strategy._indicators[0]._opts, '._opts and .name were not unpickled' - bt.plot(results=stats, resample='2d', open_browser=False) + assert stats._strategy._indicators[0]._opts, "._opts and .name were not unpickled" + bt.plot(results=stats, resample="2d", open_browser=False) class TestDocs(TestCase): - DOCS_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'doc') + DOCS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "doc") @unittest.skipUnless(os.path.isdir(DOCS_DIR), "docs dir doesn't exist") - @unittest.skipUnless(sys.platform.startswith('linux'), "test_examples requires mp.start_method=fork") + @unittest.skipUnless( + sys.platform.startswith("linux"), "test_examples requires mp.start_method=fork" + ) def test_examples(self): import backtesting - examples = glob(os.path.join(self.DOCS_DIR, 'examples', '*.py')) + + examples = glob(os.path.join(self.DOCS_DIR, "examples", "*.py")) self.assertGreaterEqual(len(examples), 4) - with chdir(gettempdir()), \ - patch(backtesting, 'Pool', mp.get_context('fork').Pool): + with chdir(gettempdir()), patch(backtesting, "Pool", mp.get_context("fork").Pool): for file in examples: with self.subTest(example=os.path.basename(file)): run_path(file) @@ -1058,8 
+1174,7 @@ def test_backtest_run_docstring_contains_stats_keys(self): self.assertIn(key, Backtest.run.__doc__) def test_readme_contains_stats_keys(self): - with open(os.path.join(os.path.dirname(__file__), - '..', '..', 'README.md')) as f: + with open(os.path.join(os.path.dirname(__file__), "..", "..", "README.md")) as f: readme = f.read() stats = Backtest(SHORT_DATA, SmaCross).run() for key in stats.index: @@ -1074,15 +1189,15 @@ def next(self): self.buy(size=1, sl=90) arr = np.r_[100, 100, 100, 50, 50] - df = pd.DataFrame({'Open': arr, 'High': arr, 'Low': arr, 'Close': arr}) - with self.assertWarnsRegex(UserWarning, 'index is not datetime'): + df = pd.DataFrame({"Open": arr, "High": arr, "Low": arr, "Close": arr}) + with self.assertWarnsRegex(UserWarning, "index is not datetime"): bt = Backtest(df, S, cash=100, trade_on_close=True) - self.assertEqual(bt.run()._trades['ExitPrice'][0], 50) + self.assertEqual(bt.run()._trades["ExitPrice"][0], 50) def test_stats_annualized(self): - stats = Backtest(GOOG.resample('W').agg(OHLCV_AGG), SmaCross).run() - self.assertFalse(np.isnan(stats['Return (Ann.) [%]'])) - self.assertEqual(round(stats['Return (Ann.) [%]']), -3) + stats = Backtest(GOOG.resample("W").agg(OHLCV_AGG), SmaCross).run() + self.assertFalse(np.isnan(stats["Return (Ann.) [%]"])) + self.assertEqual(round(stats["Return (Ann.) 
[%]"]), -3) def test_cancel_orders(self): class S(_S): @@ -1103,38 +1218,35 @@ def coro(strat): yield arr = np.r_[100, 101, 102, 50, 51] - df = pd.DataFrame({ - 'Open': arr - 10, - 'Close': arr, 'High': arr, 'Low': arr}) - with self.assertWarnsRegex(UserWarning, 'index is not datetime'): + df = pd.DataFrame({"Open": arr - 10, "Close": arr, "High": arr, "Low": arr}) + with self.assertWarnsRegex(UserWarning, "index is not datetime"): trades = TestStrategy._Backtest(coro, df, cash=250, trade_on_close=True).run()._trades # trades = Backtest(df, S, cash=250, trade_on_close=True).run()._trades - self.assertEqual(trades['EntryBar'][0], 1) - self.assertEqual(trades['ExitBar'][0], 2) - self.assertEqual(trades['EntryPrice'][0], 101) - self.assertEqual(trades['ExitPrice'][0], 102) - self.assertEqual(trades['EntryBar'][1], 1) - self.assertEqual(trades['ExitBar'][1], 3) - self.assertEqual(trades['EntryPrice'][1], 101) - self.assertEqual(trades['ExitPrice'][1], 40) - - with self.assertWarnsRegex(UserWarning, 'index is not datetime'): + self.assertEqual(trades["EntryBar"][0], 1) + self.assertEqual(trades["ExitBar"][0], 2) + self.assertEqual(trades["EntryPrice"][0], 101) + self.assertEqual(trades["ExitPrice"][0], 102) + self.assertEqual(trades["EntryBar"][1], 1) + self.assertEqual(trades["ExitBar"][1], 3) + self.assertEqual(trades["EntryPrice"][1], 101) + self.assertEqual(trades["ExitPrice"][1], 40) + + with self.assertWarnsRegex(UserWarning, "index is not datetime"): trades = TestStrategy._Backtest(coro, df, cash=250, trade_on_close=False).run()._trades # trades = Backtest(df, S, cash=250, trade_on_close=False).run()._trades - self.assertEqual(trades['EntryBar'][0], 2) - self.assertEqual(trades['ExitBar'][0], 3) - self.assertEqual(trades['EntryPrice'][0], 92) - self.assertEqual(trades['ExitPrice'][0], 40) - self.assertEqual(trades['EntryBar'][1], 2) - self.assertEqual(trades['ExitBar'][1], 3) - self.assertEqual(trades['EntryPrice'][1], 92) - 
self.assertEqual(trades['ExitPrice'][1], 40) + self.assertEqual(trades["EntryBar"][0], 2) + self.assertEqual(trades["ExitBar"][0], 3) + self.assertEqual(trades["EntryPrice"][0], 92) + self.assertEqual(trades["ExitPrice"][0], 40) + self.assertEqual(trades["EntryBar"][1], 2) + self.assertEqual(trades["ExitBar"][1], 3) + self.assertEqual(trades["EntryPrice"][1], 92) + self.assertEqual(trades["ExitPrice"][1], 40) def test_trades_dates_match_prices(self): bt = Backtest(EURUSD, SmaCross, trade_on_close=True) trades = bt.run()._trades - self.assertEqual(EURUSD.Close[trades['ExitTime']].tolist(), - trades['ExitPrice'].tolist()) + self.assertEqual(EURUSD.Close[trades["ExitTime"]].tolist(), trades["ExitPrice"].tolist()) def test_sl_always_before_tp(self): class S(_S): @@ -1148,7 +1260,7 @@ def next(self): t.tp = 107.9 trades = Backtest(SHORT_DATA, S).run()._trades - self.assertEqual(trades['ExitPrice'].iloc[0], 104.95) + self.assertEqual(trades["ExitPrice"].iloc[0], 104.95) def test_stop_entry_and_tp_in_same_bar(self): class S(_S): @@ -1158,14 +1270,14 @@ def next(self): self.sell(stop=108, tp=105, sl=113) trades = Backtest(SHORT_DATA, S).run()._trades - self.assertEqual(trades['ExitBar'].iloc[0], 3) - self.assertEqual(trades['ExitPrice'].iloc[0], 105) + self.assertEqual(trades["ExitBar"].iloc[0], 3) + self.assertEqual(trades["ExitPrice"].iloc[0], 105) def test_optimize_datetime_index_with_timezone(self): data: pd.DataFrame = GOOG.iloc[:100] - data.index = data.index.tz_localize('Asia/Kolkata') + data.index = data.index.tz_localize("Asia/Kolkata") res = Backtest(data, SmaCross).optimize(fast=range(2, 3), slow=range(4, 5)) - self.assertGreater(res['# Trades'], 0) + self.assertGreater(res["# Trades"], 0) def test_sl_tp_values_in_trades_df(self): class S(_S): @@ -1175,5 +1287,5 @@ def next(self): self.buy(size=1, sl=99) trades = Backtest(SHORT_DATA, S).run()._trades - self.assertEqual(trades['SL'].fillna(0).tolist(), [0, 99]) - self.assertEqual(trades['TP'].fillna(0).tolist(), 
[111, 0]) + self.assertEqual(trades["SL"].fillna(0).tolist(), [0, 99]) + self.assertEqual(trades["TP"].fillna(0).tolist(), [111, 0])