|
6 | 6 | import click |
7 | 7 | import matplotlib.pyplot as plt |
8 | 8 | import pandas as pd |
| 9 | +import yaml |
9 | 10 | from tqdm import tqdm |
10 | 11 |
|
11 | 12 | from .backtest import VectorizedBacktest |
12 | 13 | from .data import SampleDataLoader, SyntheticDataGenerator |
13 | 14 | from .factors import MomentumFactor, SizeFactor, ValueFactor, VolatilityFactor |
14 | 15 | from .metrics import RiskMetrics, create_equity_curve_plot |
| 16 | +from .tuning import OptunaRunner, create_backtest_objective |
15 | 17 |
|
16 | 18 |
|
17 | 19 | @click.group() |
@@ -247,5 +249,136 @@ def backtest(data_file, signals_file, initial_capital, output, plot, plotly): |
247 | 249 | click.echo(f"Results saved -> {output}") |
248 | 250 |
|
249 | 251 |
|
def _load_tuning_config(config_path):
    """Read the YAML tuning configuration and return it as a dict.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping, or an empty dict when the file
        contains no YAML document.

    Raises:
        click.BadParameter: If the top-level YAML value is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; treat that as
    # "no overrides" instead of failing on the first ``.get`` call.
    if config_data is None:
        return {}
    if not isinstance(config_data, dict):
        raise click.BadParameter(
            "expected a YAML mapping at the top level, "
            f"got {type(config_data).__name__}",
            param_hint="--config",
        )
    return config_data


@cli.command()
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True),
    help="YAML configuration file for hyperparameter tuning",
)
@click.option(
    "--data-file",
    "-d",
    default="data_sample/sample_prices.csv",
    help="Price data file path",
)
@click.option(
    "--factor-type",
    "-f",
    type=click.Choice(["momentum", "value", "size", "volatility"]),
    default="momentum",
    help="Factor type to optimize",
)
@click.option(
    "--n-trials",
    "-n",
    default=100,
    help="Number of optimization trials",
)
@click.option(
    "--metric",
    "-m",
    default="sharpe_ratio",
    help="Metric to optimize (sharpe_ratio, total_return, cagr, etc.)",
)
@click.option(
    "--output",
    "-o",
    default="output/tuning_results.json",
    help="Output file for tuning results",
)
@click.option(
    "--storage",
    "-s",
    default=None,
    help="RDB storage URL (e.g., sqlite:///optuna.db) for distributed tuning",
)
@click.option(
    "--pruner",
    "-p",
    type=click.Choice(["none", "median", "percentile"]),
    default="median",
    help="Pruning strategy for early stopping",
)
@click.option(
    "--study-name",
    default="optuna_study",
    help="Name of the Optuna study",
)
def autotune(
    config,
    data_file,
    factor_type,
    n_trials,
    metric,
    output,
    storage,
    pruner,
    study_name,
):
    """Run hyperparameter optimization with Optuna.
    \f

    Everything after the form feed above is hidden from ``--help``, so the
    command's help text stays the one-line summary.

    Args:
        config: Optional path of a YAML file. Any of the keys ``data_file``,
            ``factor_type``, ``n_trials``, ``metric``, ``output``,
            ``storage``, ``pruner`` and ``study_name`` found in it REPLACE
            the corresponding command-line value.
        data_file: CSV file with price data; when the path does not exist
            the bundled sample prices are loaded instead.
        factor_type: Factor whose parameters are optimized.
        n_trials: Number of optimization trials.
        metric: Name of the metric handed to the objective.
        output: Destination passed to ``runner.save_results``.
        storage: Optional storage URL forwarded to ``OptunaRunner``.
        pruner: Pruning strategy name forwarded to ``OptunaRunner``.
        study_name: Study name forwarded to ``OptunaRunner``.

    Raises:
        click.BadParameter: If the config file's top level is not a mapping.
    """
    click.echo("Starting hyperparameter optimization...")

    # Load configuration from YAML if provided; file values win over CLI
    # values, each falling back to the CLI value when the key is absent.
    if config:
        config_data = _load_tuning_config(config)
        data_file = config_data.get("data_file", data_file)
        factor_type = config_data.get("factor_type", factor_type)
        n_trials = config_data.get("n_trials", n_trials)
        metric = config_data.get("metric", metric)
        output = config_data.get("output", output)
        storage = config_data.get("storage", storage)
        pruner = config_data.get("pruner", pruner)
        study_name = config_data.get("study_name", study_name)

    # Load data, falling back to the sample data set when the file is missing
    if Path(data_file).exists():
        prices = pd.read_csv(data_file, index_col=0, parse_dates=True)
    else:
        click.echo("Data file not found, using sample data...")
        loader = SampleDataLoader()
        prices = loader.load_sample_prices()

    click.echo(f"Optimizing {factor_type} factor with {n_trials} trials...")
    click.echo(f"Optimizing metric: {metric}")

    # Make sure the results directory exists BEFORE spending time on the
    # trials, so a missing directory cannot cost a whole optimization run.
    Path(output).parent.mkdir(parents=True, exist_ok=True)

    # Create objective function
    objective = create_backtest_objective(
        prices=prices,
        factor_type=factor_type,
        metric=metric,
    )

    # Only these metrics are maximized; every other metric name is minimized.
    # NOTE(review): other "higher is better" metrics would currently be
    # minimized - confirm against the metrics the objective supports.
    maximize_metrics = ("sharpe_ratio", "total_return", "cagr")
    direction = "maximize" if metric in maximize_metrics else "minimize"

    # Create and run Optuna runner
    runner = OptunaRunner(
        search_space={},  # Not used when using create_backtest_objective
        objective=objective,
        n_trials=n_trials,
        study_name=study_name,
        storage=storage,
        pruner=pruner,
        direction=direction,
    )

    # Run optimization
    results = runner.optimize()

    # Save results
    runner.save_results(output)

    click.echo("\n" + "=" * 60)
    click.echo("Optimization Results")
    click.echo("=" * 60)
    click.echo(f"Best parameters: {results['best_params']}")
    click.echo(f"Best {metric}: {results['best_value']:.4f}")
    click.echo(f"Total trials: {len(results['trial_history'])}")
    click.echo(f"Results saved -> {output}")
| 381 | + |
| 382 | + |
250 | 383 | if __name__ == "__main__": |
251 | 384 | cli() |
0 commit comments