Skip to content

Commit 069a313

Browse files
committed
fix typing errors
1 parent 8935f09 commit 069a313

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99
from pathlib import Path
1010
from typing import Any, TypeVar
1111

12-
import numpy as np
1312
import optuna
1413
import torch
1514
from optuna.trial import Trial
16-
from pydantic import BaseModel, Field
15+
from pydantic import BaseModel, Field, ValidationInfo, field_validator
1716
from typing_extensions import assert_never
1817

1918
from autointent import Dataset
@@ -48,9 +47,6 @@ def n_possible_values(self) -> int:
4847
Returns:
4948
The number of possible values.
5049
"""
51-
if self.log:
52-
return int(np.logspace(np.log10(self.low), np.log10(self.high), num=self.step))
53-
5450
return (self.high - self.low) // self.step + 1
5551

5652

@@ -62,6 +58,26 @@ class ParamSpaceFloat(ParamSpace):
6258
step: float | None = Field(None, description="Step size for the search space (if applicable).")
6359
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
6460

61+
@field_validator("step")
62+
@classmethod
63+
def validate_step_with_log(cls, v: float | None, info: ValidationInfo) -> float | None:
64+
"""Validate that step is not used when log is True.
65+
66+
Args:
67+
v: The step value to validate
68+
info: Validation info containing other field values
69+
70+
Returns:
71+
The validated step value
72+
73+
Raises:
74+
ValueError: If step is provided when log is True
75+
"""
76+
if info.data.get("log", False) and v is not None:
77+
msg = "Step cannot be used when log is True. See optuna docs on `suggest_float`."
78+
raise ValueError(msg)
79+
return v
80+
6581
def n_possible_values(self) -> int | None:
6682
"""Calculate the number of possible values in the search space.
6783
@@ -70,9 +86,7 @@ def n_possible_values(self) -> int | None:
7086
"""
7187
if self.step is None:
7288
return None
73-
if self.log:
74-
return int(np.logspace(np.log10(self.low), np.log10(self.high), num=self.step))
75-
return (self.high - self.low) // self.step + 1
89+
return int((self.high - self.low) // self.step) + 1
7690

7791

7892
logger = logging.getLogger(__name__)
@@ -256,25 +270,27 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic
256270
raise TypeError(msg)
257271
return res
258272

259-
def _n_possible_combinations(self, search_space: dict[str, Any]) -> int:
273+
def _n_possible_combinations(self, search_space: dict[str, Any]) -> int | None:
260274
"""Calculate the number of possible combinations in the search space.
261275
262276
Args:
263277
search_space: The parameter search space.
278+
279+
Returns:
280+
The number of possible combinations or None if search space is continuous.
264281
"""
265282
n_combinations = 1
266283
for param_space in search_space.values():
267284
if isinstance(param_space, list):
268285
n_combinations *= len(param_space)
286+
elif param_space_int := self._parse_param_space(param_space, ParamSpaceInt):
287+
n_combinations *= param_space_int.n_possible_values()
288+
elif param_space_float := self._parse_param_space(param_space, ParamSpaceFloat):
289+
n_possible_values = param_space_float.n_possible_values()
290+
if n_possible_values is None:
291+
return None
292+
n_combinations *= n_possible_values
269293
else:
270-
param_space_int = self._parse_param_space(param_space, ParamSpaceInt)
271-
if param_space_int is not None:
272-
n_combinations *= param_space_int.n_possible_values()
273-
continue
274-
param_space_float = self._parse_param_space(param_space, ParamSpaceFloat)
275-
if param_space_float is not None:
276-
n_combinations *= param_space_float.n_possible_values()
277-
continue
278294
assert_never(param_space)
279295
return n_combinations
280296

@@ -366,7 +382,7 @@ def _reformat_search_space(self, module_search_space: dict[str, Any]) -> tuple[d
366382
continue
367383
if isinstance(param_space, list):
368384
res[param_name] = param_space
369-
elif self._is_valid_param_space(param_space, ParamSpaceInt) or self._is_valid_param_space(
385+
elif self._parse_param_space(param_space, ParamSpaceInt) or self._parse_param_space(
370386
param_space, ParamSpaceFloat
371387
):
372388
res[param_name] = [param_space["low"], param_space["high"]]

0 commit comments

Comments
 (0)