@@ -63,7 +63,7 @@ def __init__(
6363 def fit (
6464 self ,
6565 context : Context ,
66- sampler : SamplerType = "brute " ,
66+ sampler : SamplerType = "tpe " ,
6767 n_trials : int | None = None ,
6868 timeout : float | None = None ,
6969 n_jobs : int = 1 ,
@@ -82,15 +82,12 @@ def fit(
8282 """
8383 self ._logger .info ("Starting %s node optimization..." , self .node_info .node_type .value )
8484
85+ n_trials = n_trials or 10
86+
8587 if sampler == "tpe" :
8688 sampler_instance = optuna .samplers .TPESampler (seed = context .seed )
87- n_trials = n_trials or 10
88- elif sampler == "brute" :
89- sampler_instance = optuna .samplers .BruteForceSampler (seed = context .seed ) # type: ignore[assignment]
90- n_trials = None
9189 elif sampler == "random" :
9290 sampler_instance = optuna .samplers .RandomSampler (seed = context .seed ) # type: ignore[assignment]
93- n_trials = n_trials or 10
9491 else :
9592 assert_never (sampler )
9693
@@ -101,7 +98,7 @@ def fit(
10198 sampler = sampler_instance ,
10299 n_trials = n_trials ,
103100 )
104- self ._counter = max ( self . _counter , finished_trials )
101+ self ._counter = finished_trials # zero if study is newly created
105102
106103 optuna .logging .set_verbosity (optuna .logging .WARNING )
107104 obj = partial (self .objective , search_space = self .modules_search_spaces , context = context )
@@ -364,8 +361,8 @@ def load_or_create_study(
364361 study_name : str ,
365362 context : Context ,
366363 sampler : optuna .samplers .BaseSampler ,
364+ n_trials : int ,
367365 direction : str = "maximize" ,
368- n_trials : int = 10 ,
369366) -> tuple [optuna .Study , int , int ]:
370367 """Load an existing study or create a new one if it doesn't exist.
371368
@@ -396,7 +393,7 @@ def load_or_create_study(
396393 # Find the highest trial number to continue counting
397394 finished_trials = max (t .number for t in study .trials ) + 1
398395 # Calculate remaining trials if n_trials is specified
399- remaining_trials = n_trials if n_trials is None else max (0 , n_trials - len (study .trials ))
396+ remaining_trials = max (0 , n_trials - len (study .trials ))
400397
401398 context .load_optimization_info ()
402399 return study , finished_trials , remaining_trials # noqa: TRY300
0 commit comments