66import click
77import matplotlib .pyplot as plt
88import pandas as pd
9+ from tqdm import tqdm
910
1011from .backtest import VectorizedBacktest
1112from .data import SampleDataLoader , SyntheticDataGenerator
@@ -30,9 +31,14 @@ def generate_data(output, symbols, days):
3031 click .echo ("Generating synthetic price data..." )
3132
3233 generator = SyntheticDataGenerator ()
33- prices = generator .generate_price_data (
34- n_symbols = symbols , days = days , start_date = "2020-01-01"
35- )
34+ all_prices = []
35+ for _ in tqdm (range (symbols ), desc = "Generating price series" ):
36+ prices = generator .generate_price_data (
37+ n_symbols = 1 , days = days , start_date = "2020-01-01"
38+ )
39+ all_prices .append (prices )
40+
41+ prices = pd .concat (all_prices , axis = 1 )
3642
3743 # Ensure output directory exists
3844 output_path = Path (output )
@@ -94,8 +100,11 @@ def compute_factors(data_file, factors, output):
94100 vol = VolatilityFactor (lookback = 21 )
95101 factor_data ["volatility" ] = vol .compute (prices )
96102
97- # Combine factors (simple average for demo)
98- combined_signals = pd .DataFrame ({k : v .mean (axis = 1 ) for k , v in factor_data .items ()})
103+ combined_signals_dict = {}
104+ for k , v in tqdm (factor_data .items (), desc = "Averaging factors" ):
105+ combined_signals_dict [k ] = v .mean (axis = 1 )
106+
107+ combined_signals = pd .DataFrame (combined_signals_dict )
99108 combined_signals ["composite" ] = combined_signals .mean (axis = 1 )
100109
101110 # Save results
@@ -148,7 +157,7 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
148157 # Load signals
149158 if Path (signals_file ).exists ():
150159 signals_data = pd .read_csv (signals_file , index_col = 0 , parse_dates = True )
151- # Use composite signal if available, otherwise first column
160+ # If a ' composite' signal column exists, use it; otherwise, fall back to the first available signal column.
152161 if "composite" in signals_data .columns :
153162 signals = signals_data ["composite" ]
154163 else :
@@ -158,27 +167,27 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
158167 momentum = MomentumFactor (lookback = 63 )
159168 signals = momentum .compute (prices ).mean (axis = 1 )
160169
161- # Ensure signals align with prices
170+ # Align dates
162171 common_dates = prices .index .intersection (signals .index )
163172 prices = prices .loc [common_dates ]
164173 signals = signals .loc [common_dates ]
165174
166- # Expand signals to all symbols (simplified - same signal for all)
175+ # Expand signals across symbols
167176 signal_matrix = pd .DataFrame (
168177 dict .fromkeys (prices .columns , signals ), index = signals .index
169178 )
170179
171- # Run backtest
180+ # Use the original vectorized run() method for performance
181+
172182 backtest = VectorizedBacktest (
173183 prices = prices ,
174184 signals = signal_matrix ,
175185 initial_capital = initial_capital ,
176186 transaction_cost = 0.001 ,
177187 )
178-
179188 results = backtest .run (weight_scheme = "rank" )
180189
181- # Calculate metrics
190+ # Metrics
182191 metrics_calc = RiskMetrics (results ["returns" ])
183192 metrics = metrics_calc .calculate_all ()
184193
@@ -195,18 +204,16 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
195204 with open (output_path , "w" ) as f :
196205 json .dump (results_dict , f , indent = 2 )
197206
198- # Generate plot
207+ # Plotting
199208 if plot :
200209 plt .figure (figsize = (12 , 8 ))
201210
202- # Plot portfolio value
203211 plt .subplot (2 , 1 , 1 )
204212 plt .plot (results ["portfolio_value" ].index , results ["portfolio_value" ].values )
205213 plt .title ("Portfolio Value" )
206214 plt .ylabel ("USD" )
207215 plt .grid (True )
208216
209- # Plot returns
210217 plt .subplot (2 , 1 , 2 )
211218 plt .bar (results ["returns" ].index , results ["returns" ].values , alpha = 0.7 )
212219 plt .title ("Daily Returns" )
@@ -220,7 +227,6 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
220227
221228 click .echo (f"Plot saved -> { plot_path } " )
222229
223- # Generate Plotly HTML chart if requested
224230 if plotly :
225231 html_path = output_path .parent / "backtest_plot.html"
226232
0 commit comments