@@ -146,21 +146,25 @@ def objective(
146146
147147 return target_metric
148148
149- def suggest (
150- self , trial : Trial , search_space : dict [str , ParamSpaceInt | ParamSpaceFloat | list [Any ]]
151- ) -> dict [str , Any ]:
149+ def suggest (self , trial : Trial , search_space : dict [str , Any | list [Any ]]) -> dict [str , Any ]:
152150 res : dict [str , Any ] = {}
151+
152+ def is_valid_param_space (
153+ param_space : dict [str , Any ], space_type : type [ParamSpaceInt | ParamSpaceFloat ]
154+ ) -> bool :
155+ try :
156+ space_type (** param_space )
157+ return True # noqa: TRY300
158+ except ValueError :
159+ return False
160+
153161 for param_name , param_space in search_space .items ():
154162 if isinstance (param_space , list ):
155163 res [param_name ] = trial .suggest_categorical (param_name , choices = param_space )
156- elif isinstance (param_space , ParamSpaceInt ):
157- res [param_name ] = trial .suggest_int (
158- param_name , low = param_space .low , high = param_space .high , step = param_space .step , log = param_space .log
159- )
160- elif isinstance (param_space , ParamSpaceFloat ):
161- res [param_name ] = trial .suggest_float (
162- param_name , low = param_space .low , high = param_space .high , step = param_space .step , log = param_space .log
163- )
164+ elif is_valid_param_space (param_space , ParamSpaceInt ):
165+ res [param_name ] = trial .suggest_int (param_name , ** param_space )
166+ elif is_valid_param_space (param_space , ParamSpaceFloat ):
167+ res [param_name ] = trial .suggest_float (param_name , ** param_space )
164168 else :
165169 msg = f"Unsupported type of param search space: { param_space } "
166170 raise TypeError (msg )
0 commit comments