44import itertools as it
55import json
66import logging
7+ import os
78from copy import deepcopy
89from functools import partial
910from pathlib import Path
@@ -59,59 +60,60 @@ def __init__(
5960 self .validate_search_space (search_space )
6061 self .modules_search_spaces = search_space
6162
62- def fit (self , context : Context , sampler : SamplerType = "brute" , n_jobs : int = 1 ) -> None :
63+ def fit (
64+ self ,
65+ context : Context ,
66+ sampler : SamplerType = "brute" ,
67+ n_trials : int | None = None ,
68+ timeout : float | None = None ,
69+ n_jobs : int = 1 ,
70+ ) -> None :
6371 """Performs the optimization process for the node.
6472
6573 Args:
6674 context: The optimization context containing relevant data.
6775 sampler: The sampling strategy used for optimization.
76+ n_trials: Number of optuna trials.
77+ timeout: Number of secords for optimizing the whole node.
6878 n_jobs: The number of parallel jobs to run during optimization.
6979
7080 Raises:
7181 AssertionError: If an invalid sampler type is provided.
7282 """
7383 self ._logger .info ("Starting %s node optimization..." , self .node_info .node_type .value )
74- for search_space in deepcopy (self .modules_search_spaces ):
75- self ._counter : int = 0
76- module_name = search_space .pop ("module_name" )
77- n_trials = search_space .pop ("n_trials" , None )
78-
79- if sampler == "tpe" :
80- sampler_instance = optuna .samplers .TPESampler (seed = context .seed )
81- n_trials = n_trials or 10
82- elif sampler == "brute" :
83- sampler_instance = optuna .samplers .BruteForceSampler (seed = context .seed ) # type: ignore[assignment]
84- n_trials = None
85- elif sampler == "random" :
86- sampler_instance = optuna .samplers .RandomSampler (seed = context .seed ) # type: ignore[assignment]
87- n_trials = n_trials or 10
88- else :
89- assert_never (sampler )
90-
91- if n_trials and (possible_combinations := self ._n_possible_combinations (search_space )):
92- n_trials = min (possible_combinations , n_trials )
9384
94- study , finished_trials , n_trials = load_or_create_study (
95- study_name = f"{ self .node_info .node_type } _{ module_name } " ,
96- context = context ,
97- direction = "maximize" ,
98- sampler = sampler_instance ,
99- n_trials = n_trials ,
100- )
101- self ._counter = max (self ._counter , finished_trials )
85+ if sampler == "tpe" :
86+ sampler_instance = optuna .samplers .TPESampler (seed = context .seed )
87+ n_trials = n_trials or 10
88+ elif sampler == "brute" :
89+ sampler_instance = optuna .samplers .BruteForceSampler (seed = context .seed ) # type: ignore[assignment]
90+ n_trials = None
91+ elif sampler == "random" :
92+ sampler_instance = optuna .samplers .RandomSampler (seed = context .seed ) # type: ignore[assignment]
93+ n_trials = n_trials or 10
94+ else :
95+ assert_never (sampler )
96+
97+ study , finished_trials , n_trials = load_or_create_study (
98+ study_name = self .node_info .node_type ,
99+ context = context ,
100+ direction = "maximize" ,
101+ sampler = sampler_instance ,
102+ n_trials = n_trials ,
103+ )
104+ self ._counter = max (self ._counter , finished_trials )
102105
103- optuna .logging .set_verbosity (optuna .logging .WARNING )
104- obj = partial (self .objective , module_name = module_name , search_space = search_space , context = context )
106+ optuna .logging .set_verbosity (optuna .logging .WARNING )
107+ obj = partial (self .objective , search_space = self . modules_search_spaces , context = context )
105108
106- study .optimize (obj , n_trials = n_trials , n_jobs = n_jobs )
109+ study .optimize (obj , n_trials = n_trials , n_jobs = n_jobs , gc_after_trial = True , timeout = timeout )
107110
108111 self ._logger .info ("%s node optimization is finished!" , self .node_info .node_type )
109112
110113 def objective (
111114 self ,
112115 trial : Trial ,
113- module_name : str ,
114- search_space : dict [str , ParamSpaceInt | ParamSpaceFloat | list [Any ]],
116+ search_space : list [dict [str , Any ]],
115117 context : Context ,
116118 ) -> float :
117119 """Defines the objective function for optimization.
@@ -125,13 +127,17 @@ def objective(
125127 Returns:
126128 The value of the target metric for the given trial.
127129 """
128- config = self .suggest (trial , search_space )
130+ module_name , module_hyperparams = self ._suggest_module_and_hyperparams (trial , search_space )
129131
130- self ._logger .debug ("Initializing %s module with config: %s" , module_name , json .dumps (config ))
131- module = self .node_info .modules_available [module_name ].from_context (context , ** config )
132- config .update (module .get_implicit_initialization_params ())
132+ self ._logger .debug ("Initializing %s module with config: %s" , module_name , json .dumps (module_hyperparams ))
133+ module = self .node_info .modules_available [module_name ].from_context (context , ** module_hyperparams )
134+ module_hyperparams .update (module .get_implicit_initialization_params ())
133135
134- context .callback_handler .start_module (module_name = module .trial_name , num = self ._counter , module_kwargs = config )
136+ context .callback_handler .start_module (
137+ module_name = module .trial_name ,
138+ num = self ._counter ,
139+ module_kwargs = module_hyperparams ,
140+ )
135141
136142 self ._logger .debug ("Scoring %s module..." , module_name )
137143
@@ -148,7 +154,7 @@ def objective(
148154 context .optimization_info .log_module_optimization (
149155 node_type = self .node_info .node_type ,
150156 module_name = module_name ,
151- module_params = config ,
157+ module_params = module_hyperparams ,
152158 metric_value = target_metric ,
153159 metric_name = self .target_metric ,
154160 metrics = quality_metrics ,
@@ -166,30 +172,32 @@ def objective(
166172 self ._counter += 1
167173 return target_metric
168174
169- def suggest (self , trial : Trial , search_space : dict [str , Any | list [Any ]]) -> dict [str , Any ]:
170- """Suggests parameter values based on the search space.
171-
172- Args:
173- trial: The Optuna trial instance.
174- search_space: A dictionary defining the parameter search space.
175-
176- Returns:
177- A dictionary containing the suggested parameter values.
178-
179- Raises:
180- TypeError: If an unsupported parameter search space type is encountered.
181- """
175+ def _suggest_module_and_hyperparams (
176+ self , trial : Trial , search_space : list [dict [str , Any ]]
177+ ) -> tuple [str , dict [str , Any ]]:
178+ """Sample module name and its hyperparams from given search space."""
179+ n_modules = len (search_space )
180+ id_module_chosen = trial .suggest_categorical ("module_idx" , list (range (n_modules )))
181+ module_chosen = deepcopy (search_space [id_module_chosen ])
182+ module_name = module_chosen .pop ("module_name" )
183+ module_config = self ._suggest_hyperparams (trial , f"{ module_name } _{ id_module_chosen } " , module_chosen )
184+ return module_name , module_config
185+
186+ def _suggest_hyperparams (
187+ self , trial : Trial , module_name : str , search_space : dict [str , Any | list [Any ]]
188+ ) -> dict [str , Any ]:
182189 res : dict [str , Any ] = {}
183190
184191 for param_name , param_space in search_space .items ():
192+ name = f"{ module_name } _{ param_name } "
185193 if isinstance (param_space , list ):
186- res [param_name ] = trial .suggest_categorical (param_name , choices = param_space )
194+ res [param_name ] = trial .suggest_categorical (name , choices = param_space )
187195 elif self ._parse_param_space (param_space , ParamSpaceInt ):
188- res [param_name ] = trial .suggest_int (param_name , ** param_space )
196+ res [param_name ] = trial .suggest_int (name , ** param_space )
189197 elif self ._parse_param_space (param_space , ParamSpaceFloat ):
190- res [param_name ] = trial .suggest_float (param_name , ** param_space )
198+ res [param_name ] = trial .suggest_float (name , ** param_space )
191199 else :
192- msg = f"Unsupported type of param search space: { param_space } "
200+ msg = f"Unsupported type of param search space { name } : { param_space } "
193201 raise TypeError (msg )
194202 return res
195203
@@ -294,6 +302,10 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
294302 def validate_search_space (self , search_space : list [dict [str , Any ]]) -> None :
295303 """Check if search space is configured correctly."""
296304 validated_search_space = SearchSpaceConfig (search_space ).model_dump ()
305+
306+ if not bool (int (os .getenv ("AUTOINTENT_EXTRA_VALIDATION" , "0" ))):
307+ return
308+
297309 for module_search_space in validated_search_space :
298310 module_search_space_no_optuna , module_name = self ._reformat_search_space (deepcopy (module_search_space ))
299311
0 commit comments