Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ __pycache__/
*.py[cod]
*$py.class
src/quant_research_starter.egg-info/PKG-INFO

myenv/
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ The backtester supports different rebalancing frequencies to match your strategy

```python
from quant_research_starter.backtest import VectorizedBacktest

# Daily rebalancing (default)
bt_daily = VectorizedBacktest(prices, signals, rebalance_freq="D")

Expand Down
36 changes: 21 additions & 15 deletions src/quant_research_starter/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import click
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from .backtest import VectorizedBacktest
from .data import SampleDataLoader, SyntheticDataGenerator
Expand All @@ -30,9 +31,14 @@ def generate_data(output, symbols, days):
click.echo("Generating synthetic price data...")

generator = SyntheticDataGenerator()
prices = generator.generate_price_data(
n_symbols=symbols, days=days, start_date="2020-01-01"
)
all_prices = []
for _ in tqdm(range(symbols), desc="Generating price series"):
prices = generator.generate_price_data(
n_symbols=1, days=days, start_date="2020-01-01"
)
all_prices.append(prices)

prices = pd.concat(all_prices, axis=1)

# Ensure output directory exists
output_path = Path(output)
Expand Down Expand Up @@ -94,8 +100,11 @@ def compute_factors(data_file, factors, output):
vol = VolatilityFactor(lookback=21)
factor_data["volatility"] = vol.compute(prices)

# Combine factors (simple average for demo)
combined_signals = pd.DataFrame({k: v.mean(axis=1) for k, v in factor_data.items()})
combined_signals_dict = {}
for k, v in tqdm(factor_data.items(), desc="Averaging factors"):
combined_signals_dict[k] = v.mean(axis=1)

combined_signals = pd.DataFrame(combined_signals_dict)
combined_signals["composite"] = combined_signals.mean(axis=1)

# Save results
Expand Down Expand Up @@ -148,7 +157,7 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
# Load signals
if Path(signals_file).exists():
signals_data = pd.read_csv(signals_file, index_col=0, parse_dates=True)
# Use composite signal if available, otherwise first column
# If a 'composite' signal column exists, use it; otherwise, fall back to the first available signal column.
if "composite" in signals_data.columns:
signals = signals_data["composite"]
else:
Expand All @@ -158,27 +167,27 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
momentum = MomentumFactor(lookback=63)
signals = momentum.compute(prices).mean(axis=1)

# Ensure signals align with prices
# Align dates
common_dates = prices.index.intersection(signals.index)
prices = prices.loc[common_dates]
signals = signals.loc[common_dates]

# Expand signals to all symbols (simplified - same signal for all)
# Expand signals across symbols
signal_matrix = pd.DataFrame(
dict.fromkeys(prices.columns, signals), index=signals.index
)

# Run backtest
# Use the original vectorized run() method for performance

backtest = VectorizedBacktest(
prices=prices,
signals=signal_matrix,
initial_capital=initial_capital,
transaction_cost=0.001,
)

results = backtest.run(weight_scheme="rank")

# Calculate metrics
# Metrics
metrics_calc = RiskMetrics(results["returns"])
metrics = metrics_calc.calculate_all()

Expand All @@ -195,18 +204,16 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
with open(output_path, "w") as f:
json.dump(results_dict, f, indent=2)

# Generate plot
# Plotting
if plot:
plt.figure(figsize=(12, 8))

# Plot portfolio value
plt.subplot(2, 1, 1)
plt.plot(results["portfolio_value"].index, results["portfolio_value"].values)
plt.title("Portfolio Value")
plt.ylabel("USD")
plt.grid(True)

# Plot returns
plt.subplot(2, 1, 2)
plt.bar(results["returns"].index, results["returns"].values, alpha=0.7)
plt.title("Daily Returns")
Expand All @@ -220,7 +227,6 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):

click.echo(f"Plot saved -> {plot_path}")

# Generate Plotly HTML chart if requested
if plotly:
html_path = output_path.parent / "backtest_plot.html"

Expand Down
Loading