Skip to content

Commit c4f5137

Browse files
feat: added loading bars centrally cuz file wise approach resulted in pytest failing a lot
closes #32
1 parent 857b525 commit c4f5137

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ __pycache__/
55
*.py[cod]
66
*$py.class
77
src/quant_research_starter.egg-info/PKG-INFO
8+
9+
myenv/

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ The backtester supports different rebalancing frequencies to match your strategy
125125

126126
```python
127127
from quant_research_starter.backtest import VectorizedBacktest
128-
129128
# Daily rebalancing (default)
130129
bt_daily = VectorizedBacktest(prices, signals, rebalance_freq="D")
131130

src/quant_research_starter/cli.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import click
77
import matplotlib.pyplot as plt
88
import pandas as pd
9+
from tqdm import tqdm
910

1011
from .backtest import VectorizedBacktest
1112
from .data import SampleDataLoader, SyntheticDataGenerator
@@ -30,9 +31,14 @@ def generate_data(output, symbols, days):
3031
click.echo("Generating synthetic price data...")
3132

3233
generator = SyntheticDataGenerator()
33-
prices = generator.generate_price_data(
34-
n_symbols=symbols, days=days, start_date="2020-01-01"
35-
)
34+
all_prices = []
35+
for _ in tqdm(range(symbols), desc="Generating price series"):
36+
prices = generator.generate_price_data(
37+
n_symbols=1, days=days, start_date="2020-01-01"
38+
)
39+
all_prices.append(prices)
40+
41+
prices = pd.concat(all_prices, axis=1)
3642

3743
# Ensure output directory exists
3844
output_path = Path(output)
@@ -94,8 +100,11 @@ def compute_factors(data_file, factors, output):
94100
vol = VolatilityFactor(lookback=21)
95101
factor_data["volatility"] = vol.compute(prices)
96102

97-
# Combine factors (simple average for demo)
98-
combined_signals = pd.DataFrame({k: v.mean(axis=1) for k, v in factor_data.items()})
103+
combined_signals_dict = {}
104+
for k, v in tqdm(factor_data.items(), desc="Averaging factors"):
105+
combined_signals_dict[k] = v.mean(axis=1)
106+
107+
combined_signals = pd.DataFrame(combined_signals_dict)
99108
combined_signals["composite"] = combined_signals.mean(axis=1)
100109

101110
# Save results
@@ -148,7 +157,7 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
148157
# Load signals
149158
if Path(signals_file).exists():
150159
signals_data = pd.read_csv(signals_file, index_col=0, parse_dates=True)
151-
# Use composite signal if available, otherwise first column
160+
# If a 'composite' signal column exists, use it; otherwise, fall back to the first available signal column.
152161
if "composite" in signals_data.columns:
153162
signals = signals_data["composite"]
154163
else:
@@ -158,27 +167,27 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
158167
momentum = MomentumFactor(lookback=63)
159168
signals = momentum.compute(prices).mean(axis=1)
160169

161-
# Ensure signals align with prices
170+
# Align dates
162171
common_dates = prices.index.intersection(signals.index)
163172
prices = prices.loc[common_dates]
164173
signals = signals.loc[common_dates]
165174

166-
# Expand signals to all symbols (simplified - same signal for all)
175+
# Expand signals across symbols
167176
signal_matrix = pd.DataFrame(
168177
dict.fromkeys(prices.columns, signals), index=signals.index
169178
)
170179

171-
# Run backtest
180+
# Use the original vectorized run() method for performance
181+
172182
backtest = VectorizedBacktest(
173183
prices=prices,
174184
signals=signal_matrix,
175185
initial_capital=initial_capital,
176186
transaction_cost=0.001,
177187
)
178-
179188
results = backtest.run(weight_scheme="rank")
180189

181-
# Calculate metrics
190+
# Metrics
182191
metrics_calc = RiskMetrics(results["returns"])
183192
metrics = metrics_calc.calculate_all()
184193

@@ -195,18 +204,16 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
195204
with open(output_path, "w") as f:
196205
json.dump(results_dict, f, indent=2)
197206

198-
# Generate plot
207+
# Plotting
199208
if plot:
200209
plt.figure(figsize=(12, 8))
201210

202-
# Plot portfolio value
203211
plt.subplot(2, 1, 1)
204212
plt.plot(results["portfolio_value"].index, results["portfolio_value"].values)
205213
plt.title("Portfolio Value")
206214
plt.ylabel("USD")
207215
plt.grid(True)
208216

209-
# Plot returns
210217
plt.subplot(2, 1, 2)
211218
plt.bar(results["returns"].index, results["returns"].values, alpha=0.7)
212219
plt.title("Daily Returns")
@@ -220,7 +227,6 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
220227

221228
click.echo(f"Plot saved -> {plot_path}")
222229

223-
# Generate Plotly HTML chart if requested
224230
if plotly:
225231
html_path = output_path.parent / "backtest_plot.html"
226232

0 commit comments

Comments
 (0)