Skip to content

Commit 190a92e

Browse files
committed
fix type validation
1 parent f875e23 commit 190a92e

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,25 @@ def objective(
146146

147147
return target_metric
148148

149-
def suggest(
150-
self, trial: Trial, search_space: dict[str, ParamSpaceInt | ParamSpaceFloat | list[Any]]
151-
) -> dict[str, Any]:
149+
def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]:
152150
res: dict[str, Any] = {}
151+
152+
def is_valid_param_space(
153+
param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat]
154+
) -> bool:
155+
try:
156+
space_type(**param_space)
157+
return True # noqa: TRY300
158+
except ValueError:
159+
return False
160+
153161
for param_name, param_space in search_space.items():
154162
if isinstance(param_space, list):
155163
res[param_name] = trial.suggest_categorical(param_name, choices=param_space)
156-
elif isinstance(param_space, ParamSpaceInt):
157-
res[param_name] = trial.suggest_int(
158-
param_name, low=param_space.low, high=param_space.high, step=param_space.step, log=param_space.log
159-
)
160-
elif isinstance(param_space, ParamSpaceFloat):
161-
res[param_name] = trial.suggest_float(
162-
param_name, low=param_space.low, high=param_space.high, step=param_space.step, log=param_space.log
163-
)
164+
elif is_valid_param_space(param_space, ParamSpaceInt):
165+
res[param_name] = trial.suggest_int(param_name, **param_space)
166+
elif is_valid_param_space(param_space, ParamSpaceFloat):
167+
res[param_name] = trial.suggest_float(param_name, **param_space)
164168
else:
165169
msg = f"Unsupported type of param search space: {param_space}"
166170
raise TypeError(msg)

0 commit comments

Comments
 (0)