Skip to content

Commit 2c67c8f

Browse files
feat: pruning implemented, optuna integrated
closes #95
1 parent 520de1f commit 2c67c8f

File tree

5 files changed

+473
-0
lines changed

5 files changed

+473
-0
lines changed

examples/autotune_config.yaml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
 1+ # Example configuration for hyperparameter tuning with Optuna
 2+ # Usage: qrs autotune -c examples/autotune_config.yaml
 3+
 4+ # Data configuration
 5+ data_file: "data_sample/sample_prices.csv"
 6+
 7+ # Factor to optimize
 8+ factor_type: "momentum" # Options: momentum, value, size, volatility
 9+
10+ # Optimization settings
11+ n_trials: 100 # Number of trials to run
12+ metric: "sharpe_ratio" # Metric to optimize (sharpe_ratio, total_return, cagr, etc.)
13+
14+ # Output configuration
15+ output: "output/tuning_results.json"
16+ study_name: "momentum_factor_study"
17+
18+ # Pruning configuration (for early stopping of bad trials)
19+ # Options: none, median, percentile
20+ pruner: "median"
21+
22+ # Optional: RDB storage for distributed tuning runs
23+ # Uncomment and configure for multi-worker setups
24+ # storage: "sqlite:///optuna.db"
25+ # For PostgreSQL: "postgresql://user:password@localhost/dbname"
26+ # For MySQL: "mysql://user:password@localhost/dbname"
27+

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ dependencies = [
31 31      "uvicorn>=0.23.0",
32 32      "python-dotenv>=1.0.0",
33 33      "requests>=2.31.0",
   34+     "optuna>=3.0.0",
   35+     "pyyaml>=6.0",
34 36  ]
35 37
36 38  [project.optional-dependencies]

src/quant_research_starter/cli.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import click
77
import matplotlib.pyplot as plt
88
import pandas as pd
9+
import yaml
910
from tqdm import tqdm
1011

1112
from .backtest import VectorizedBacktest
1213
from .data import SampleDataLoader, SyntheticDataGenerator
1314
from .factors import MomentumFactor, SizeFactor, ValueFactor, VolatilityFactor
1415
from .metrics import RiskMetrics, create_equity_curve_plot
16+
from .tuning import OptunaRunner, create_backtest_objective
1517

1618

1719
@click.group()
@@ -247,5 +249,138 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly):
247249
click.echo(f"Results saved -> {output}")
248250

249251

252+
@cli.command()
253+
@click.option(
254+
"--config",
255+
"-c",
256+
type=click.Path(exists=True),
257+
help="YAML configuration file for hyperparameter tuning",
258+
)
259+
@click.option(
260+
"--data-file",
261+
"-d",
262+
default="data_sample/sample_prices.csv",
263+
help="Price data file path",
264+
)
265+
@click.option(
266+
"--factor-type",
267+
"-f",
268+
type=click.Choice(["momentum", "value", "size", "volatility"]),
269+
default="momentum",
270+
help="Factor type to optimize",
271+
)
272+
@click.option(
273+
"--n-trials",
274+
"-n",
275+
default=100,
276+
help="Number of optimization trials",
277+
)
278+
@click.option(
279+
"--metric",
280+
"-m",
281+
default="sharpe_ratio",
282+
help="Metric to optimize (sharpe_ratio, total_return, cagr, etc.)",
283+
)
284+
@click.option(
285+
"--output",
286+
"-o",
287+
default="output/tuning_results.json",
288+
help="Output file for tuning results",
289+
)
290+
@click.option(
291+
"--storage",
292+
"-s",
293+
default=None,
294+
help="RDB storage URL (e.g., sqlite:///optuna.db) for distributed tuning",
295+
)
296+
@click.option(
297+
"--pruner",
298+
"-p",
299+
type=click.Choice(["none", "median", "percentile"]),
300+
default="median",
301+
help="Pruning strategy for early stopping",
302+
)
303+
@click.option(
304+
"--study-name",
305+
default="optuna_study",
306+
help="Name of the Optuna study",
307+
)
308+
def autotune(
309+
config,
310+
data_file,
311+
factor_type,
312+
n_trials,
313+
metric,
314+
output,
315+
storage,
316+
pruner,
317+
study_name,
318+
):
319+
"""Run hyperparameter optimization with Optuna."""
320+
click.echo("Starting hyperparameter optimization...")
321+
322+
# Load configuration from YAML if provided
323+
search_space = None
324+
if config:
325+
with open(config, "r") as f:
326+
config_data = yaml.safe_load(f)
327+
data_file = config_data.get("data_file", data_file)
328+
factor_type = config_data.get("factor_type", factor_type)
329+
n_trials = config_data.get("n_trials", n_trials)
330+
metric = config_data.get("metric", metric)
331+
output = config_data.get("output", output)
332+
storage = config_data.get("storage", storage)
333+
pruner = config_data.get("pruner", pruner)
334+
study_name = config_data.get("study_name", study_name)
335+
search_space = config_data.get("search_space", None)
336+
337+
# Load data
338+
if Path(data_file).exists():
339+
prices = pd.read_csv(data_file, index_col=0, parse_dates=True)
340+
else:
341+
click.echo("Data file not found, using sample data...")
342+
loader = SampleDataLoader()
343+
prices = loader.load_sample_prices()
344+
345+
click.echo(f"Optimizing {factor_type} factor with {n_trials} trials...")
346+
click.echo(f"Optimizing metric: {metric}")
347+
348+
# Create objective function
349+
objective = create_backtest_objective(
350+
prices=prices,
351+
factor_type=factor_type,
352+
metric=metric,
353+
search_space=search_space,
354+
)
355+
356+
# Create and run Optuna runner
357+
runner = OptunaRunner(
358+
objective=objective,
359+
n_trials=n_trials,
360+
study_name=study_name,
361+
storage=storage,
362+
pruner=pruner,
363+
direction=(
364+
"maximize"
365+
if metric in ["sharpe_ratio", "total_return", "cagr"]
366+
else "minimize"
367+
),
368+
)
369+
370+
# Run optimization
371+
results = runner.optimize()
372+
373+
# Save results
374+
runner.save_results(output)
375+
376+
click.echo("\n" + "=" * 60)
377+
click.echo("Optimization Results")
378+
click.echo("=" * 60)
379+
click.echo(f"Best parameters: {results['best_params']}")
380+
click.echo(f"Best {metric}: {results['best_value']:.4f}")
381+
click.echo(f"Total trials: {len(results['trial_history'])}")
382+
click.echo(f"Results saved -> {output}")
383+
384+
250385
if __name__ == "__main__":
251386
cli()
src/quant_research_starter/tuning/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+ """Hyperparameter tuning with Optuna."""
2+
3+ from .optuna_runner import OptunaRunner, create_backtest_objective
4+
5+ __all__ = ["OptunaRunner", "create_backtest_objective"]

0 commit comments

Comments (0)