@@ -17,14 +17,14 @@ class OptunaRunner:
1717
1818 def __init__ (
1919 self ,
20- search_space : Dict [str , Any ],
2120 objective : Callable [[Trial ], float ],
2221 n_trials : int = 100 ,
2322 study_name : Optional [str ] = None ,
2423 storage : Optional [Union [str , RDBStorage ]] = None ,
2524 pruner : Optional [Union [str , optuna .pruners .BasePruner ]] = None ,
2625 direction : str = "maximize" ,
2726 random_state : Optional [int ] = None ,
27+ search_space : Optional [Dict [str , Any ]] = None ,
2828 ):
2929
3030 self .search_space = search_space
@@ -139,6 +139,7 @@ def create_backtest_objective(
139139 initial_capital : float = 1_000_000 ,
140140 transaction_cost : float = 0.001 ,
141141 metric : str = "sharpe_ratio" ,
142+ search_space : Optional [Dict [str , Any ]] = None ,
142143) -> Callable [[Trial ], float ]:
143144 """
144145 Create an objective function for backtest-based hyperparameter tuning.
@@ -149,6 +150,8 @@ def create_backtest_objective(
149150 initial_capital: Initial capital for backtest.
150151 transaction_cost: Transaction cost rate.
151152 metric: Metric to optimize ("sharpe_ratio", "total_return", "cagr", etc.).
153+ search_space: Optional search space dictionary. If provided, will be used
154+ instead of default hardcoded parameter ranges.
152155
153156 Returns:
154157 Objective function that takes a Trial and returns a float.
@@ -170,15 +173,22 @@ def create_backtest_objective(
170173
171174 def objective (trial : Trial ) -> float :
172175 """Objective function for Optuna trial."""
173- if factor_type == "momentum" :
174- lookback = trial .suggest_int ("lookback" , 10 , 252 , step = 1 )
175- skip_period = trial .suggest_int ("skip_period" , 0 , 5 , step = 1 )
176- factor = FactorClass (lookback = lookback , skip_period = skip_period )
177- elif factor_type == "volatility" :
178- lookback = trial .suggest_int ("lookback" , 10 , 126 , step = 1 )
179- factor = FactorClass (lookback = lookback )
176+ # Use search_space if provided, otherwise use default hardcoded ranges
177+ if search_space :
178+ params = suggest_hyperparameters (trial , search_space )
179+ # Create factor with suggested parameters
180+ factor = FactorClass (** params )
180181 else :
181- factor = FactorClass ()
182+ # Default behavior: use hardcoded parameter ranges
183+ if factor_type == "momentum" :
184+ lookback = trial .suggest_int ("lookback" , 10 , 252 , step = 1 )
185+ skip_period = trial .suggest_int ("skip_period" , 0 , 5 , step = 1 )
186+ factor = FactorClass (lookback = lookback , skip_period = skip_period )
187+ elif factor_type == "volatility" :
188+ lookback = trial .suggest_int ("lookback" , 10 , 126 , step = 1 )
189+ factor = FactorClass (lookback = lookback )
190+ else :
191+ factor = FactorClass ()
182192 signals = factor .compute (prices )
183193 if signals .empty :
184194 return (
0 commit comments