99from pathlib import Path
1010from typing import Any , TypeVar
1111
12- import numpy as np
1312import optuna
1413import torch
1514from optuna .trial import Trial
16- from pydantic import BaseModel , Field
15+ from pydantic import BaseModel , Field , ValidationInfo , field_validator
1716from typing_extensions import assert_never
1817
1918from autointent import Dataset
@@ -48,9 +47,6 @@ def n_possible_values(self) -> int:
4847 Returns:
4948 The number of possible values.
5049 """
51- if self .log :
52- return int (np .logspace (np .log10 (self .low ), np .log10 (self .high ), num = self .step ))
53-
5450 return (self .high - self .low ) // self .step + 1
5551
5652
@@ -62,6 +58,26 @@ class ParamSpaceFloat(ParamSpace):
6258 step : float | None = Field (None , description = "Step size for the search space (if applicable)." )
6359 log : bool = Field (False , description = "Indicates whether to use a logarithmic scale." )
6460
61+ @field_validator ("step" )
62+ @classmethod
63+ def validate_step_with_log (cls , v : float | None , info : ValidationInfo ) -> float | None :
64+ """Validate that step is not used when log is True.
65+
66+ Args:
67+ v: The step value to validate
68+ info: Validation info containing other field values
69+
70+ Returns:
71+ The validated step value
72+
73+ Raises:
74+ ValueError: If step is provided when log is True
75+ """
76+ if info .data .get ("log" , False ) and v is not None :
77+ msg = "Step cannot be used when log is True. See optuna docs on `suggest_float`."
78+ raise ValueError (msg )
79+ return v
80+
6581 def n_possible_values (self ) -> int | None :
6682 """Calculate the number of possible values in the search space.
6783
@@ -70,9 +86,7 @@ def n_possible_values(self) -> int | None:
7086 """
7187 if self .step is None :
7288 return None
73- if self .log :
74- return int (np .logspace (np .log10 (self .low ), np .log10 (self .high ), num = self .step ))
75- return (self .high - self .low ) // self .step + 1
89+ return int ((self .high - self .low ) // self .step ) + 1
7690
7791
7892logger = logging .getLogger (__name__ )
@@ -256,25 +270,27 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic
256270 raise TypeError (msg )
257271 return res
258272
259- def _n_possible_combinations (self , search_space : dict [str , Any ]) -> int :
273+ def _n_possible_combinations (self , search_space : dict [str , Any ]) -> int | None :
260274 """Calculate the number of possible combinations in the search space.
261275
262276 Args:
263277 search_space: The parameter search space.
278+
279+ Returns:
280+ The number of possible combinations or None if search space is continuous.
264281 """
265282 n_combinations = 1
266283 for param_space in search_space .values ():
267284 if isinstance (param_space , list ):
268285 n_combinations *= len (param_space )
286+ elif param_space_int := self ._parse_param_space (param_space , ParamSpaceInt ):
287+ n_combinations *= param_space_int .n_possible_values ()
288+ elif param_space_float := self ._parse_param_space (param_space , ParamSpaceFloat ):
289+ n_possible_values = param_space_float .n_possible_values ()
290+ if n_possible_values is None :
291+ return None
292+ n_combinations *= n_possible_values
269293 else :
270- param_space_int = self ._parse_param_space (param_space , ParamSpaceInt )
271- if param_space_int is not None :
272- n_combinations *= param_space_int .n_possible_values ()
273- continue
274- param_space_float = self ._parse_param_space (param_space , ParamSpaceFloat )
275- if param_space_float is not None :
276- n_combinations *= param_space_float .n_possible_values ()
277- continue
278294 assert_never (param_space )
279295 return n_combinations
280296
@@ -366,7 +382,7 @@ def _reformat_search_space(self, module_search_space: dict[str, Any]) -> tuple[d
366382 continue
367383 if isinstance (param_space , list ):
368384 res [param_name ] = param_space
369- elif self ._is_valid_param_space (param_space , ParamSpaceInt ) or self ._is_valid_param_space (
385+ elif self ._parse_param_space (param_space , ParamSpaceInt ) or self ._parse_param_space (
370386 param_space , ParamSpaceFloat
371387 ):
372388 res [param_name ] = [param_space ["low" ], param_space ["high" ]]
0 commit comments