Skip to content

Commit ccf7d8f

Browse files
feat: added loading bars centrally cuz file wise approach resulted in pytest failing a lot
1 parent af423cc commit ccf7d8f

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ __pycache__/
55
*.py[cod]
66
*$py.class
77
src/quant_research_starter.egg-info/PKG-INFO
8+
9+
myenv/

src/quant_research_starter/cli.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import click
77
import matplotlib.pyplot as plt
88
import pandas as pd
9+
from tqdm import tqdm
910

1011
from .backtest import VectorizedBacktest
1112
from .data import SampleDataLoader, SyntheticDataGenerator
@@ -30,9 +31,14 @@ def generate_data(output, symbols, days):
3031
click.echo("Generating synthetic price data...")
3132

3233
generator = SyntheticDataGenerator()
33-
prices = generator.generate_price_data(
34-
n_symbols=symbols, days=days, start_date="2020-01-01"
35-
)
34+
all_prices = []
35+
for _ in tqdm(range(symbols), desc="Generating price series"):
36+
prices = generator.generate_price_data(
37+
n_symbols=1, days=days, start_date="2020-01-01"
38+
)
39+
all_prices.append(prices)
40+
41+
prices = pd.concat(all_prices, axis=1)
3642

3743
# Ensure output directory exists
3844
output_path = Path(output)
@@ -94,8 +100,13 @@ def compute_factors(data_file, factors, output):
94100
vol = VolatilityFactor(lookback=21)
95101
factor_data["volatility"] = vol.compute(prices)
96102

97-
# Combine factors (simple average for demo)
98-
combined_signals = pd.DataFrame({k: v.mean(axis=1) for k, v in factor_data.items()})
103+
combined_signals = pd.DataFrame(
104+
{
105+
k: tqdm(v.mean(axis=1), desc=f"Averaging {k} factor")
106+
for k, v in factor_data.items()
107+
}
108+
)
109+
99110
combined_signals["composite"] = combined_signals.mean(axis=1)
100111

101112
# Save results
@@ -148,7 +159,6 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
148159
# Load signals
149160
if Path(signals_file).exists():
150161
signals_data = pd.read_csv(signals_file, index_col=0, parse_dates=True)
151-
# Use composite signal if available, otherwise first column
152162
if "composite" in signals_data.columns:
153163
signals = signals_data["composite"]
154164
else:
@@ -158,27 +168,48 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
158168
momentum = MomentumFactor(lookback=63)
159169
signals = momentum.compute(prices).mean(axis=1)
160170

161-
# Ensure signals align with prices
171+
# Align dates
162172
common_dates = prices.index.intersection(signals.index)
163173
prices = prices.loc[common_dates]
164174
signals = signals.loc[common_dates]
165175

166-
# Expand signals to all symbols (simplified - same signal for all)
176+
# Expand signals across symbols
167177
signal_matrix = pd.DataFrame(
168178
dict.fromkeys(prices.columns, signals), index=signals.index
169179
)
170180

171-
# Run backtest
181+
def run_with_progress(self, weight_scheme="rank"):
182+
returns = []
183+
idx = self.prices.index
184+
185+
for i in tqdm(range(1, len(idx)), desc="Running backtest"):
186+
ret = self._compute_daily_return(
187+
self.prices.iloc[i - 1],
188+
self.prices.iloc[i],
189+
weight_scheme,
190+
)
191+
returns.append(ret)
192+
193+
results = pd.DataFrame({"returns": returns}, index=idx[1:])
194+
results["portfolio_value"] = (
195+
self.initial_capital * (1 + results["returns"]).cumprod()
196+
)
197+
results["final_value"] = results["portfolio_value"].iloc[-1]
198+
results["total_return"] = results["final_value"] / self.initial_capital - 1
199+
200+
return results
201+
202+
VectorizedBacktest.run = run_with_progress
203+
172204
backtest = VectorizedBacktest(
173205
prices=prices,
174206
signals=signal_matrix,
175207
initial_capital=initial_capital,
176208
transaction_cost=0.001,
177209
)
178-
179210
results = backtest.run(weight_scheme="rank")
180211

181-
# Calculate metrics
212+
# Metrics
182213
metrics_calc = RiskMetrics(results["returns"])
183214
metrics = metrics_calc.calculate_all()
184215

@@ -195,18 +226,16 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
195226
with open(output_path, "w") as f:
196227
json.dump(results_dict, f, indent=2)
197228

198-
# Generate plot
229+
# Plotting
199230
if plot:
200231
plt.figure(figsize=(12, 8))
201232

202-
# Plot portfolio value
203233
plt.subplot(2, 1, 1)
204234
plt.plot(results["portfolio_value"].index, results["portfolio_value"].values)
205235
plt.title("Portfolio Value")
206236
plt.ylabel("USD")
207237
plt.grid(True)
208238

209-
# Plot returns
210239
plt.subplot(2, 1, 2)
211240
plt.bar(results["returns"].index, results["returns"].values, alpha=0.7)
212241
plt.title("Daily Returns")
@@ -220,7 +249,6 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
220249

221250
click.echo(f"Plot saved -> {plot_path}")
222251

223-
# Generate Plotly HTML chart if requested
224252
if plotly:
225253
html_path = output_path.parent / "backtest_plot.html"
226254

0 commit comments

Comments
 (0)