|
3 | 3 | import gc |
4 | 4 | import itertools as it |
5 | 5 | import logging |
| 6 | +from abc import ABC, abstractmethod |
6 | 7 | from copy import deepcopy |
7 | 8 | from functools import partial |
8 | 9 | from pathlib import Path |
9 | | -from typing import Any |
| 10 | +from typing import Any, TypeVar |
10 | 11 |
|
11 | 12 | import optuna |
12 | 13 | import torch |
13 | 14 | from optuna.trial import Trial |
14 | | -from pydantic import BaseModel, Field |
| 15 | +from pydantic import BaseModel, Field, ValidationInfo, field_validator |
15 | 16 | from typing_extensions import assert_never |
16 | 17 |
|
17 | 18 | from autointent import Dataset |
|
20 | 21 | from autointent.nodes.info import NODES_INFO |
21 | 22 |
|
22 | 23 |
|
23 | | -class ParamSpaceInt(BaseModel): |
| 24 | +class ParamSpace(BaseModel, ABC): |
| 25 | + """Base class for parameter search space configuration.""" |
| 26 | + |
| 27 | + @abstractmethod |
| 28 | + def n_possible_values(self) -> int | None: |
| 29 | + """Calculate the number of possible values in the search space. |
| 30 | +
|
| 31 | + Returns: |
| 32 | + The number of possible values or None if search space is continuous. |
| 33 | + """ |
| 34 | + |
| 35 | + |
| 36 | +class ParamSpaceInt(ParamSpace): |
24 | 37 | """Integer parameter search space configuration.""" |
25 | 38 |
|
26 | 39 | low: int = Field(..., description="Lower boundary of the search space.") |
27 | 40 | high: int = Field(..., description="Upper boundary of the search space.") |
28 | 41 | step: int = Field(1, description="Step size for the search space.") |
29 | 42 | log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") |
30 | 43 |
|
| 44 | + def n_possible_values(self) -> int: |
| 45 | + """Calculate the number of possible values in the search space. |
| 46 | +
|
| 47 | + Returns: |
| 48 | + The number of possible values. |
| 49 | + """ |
| 50 | + return (self.high - self.low) // self.step + 1 |
31 | 51 |
|
32 | | -class ParamSpaceFloat(BaseModel): |
| 52 | + |
| 53 | +class ParamSpaceFloat(ParamSpace): |
33 | 54 | """Float parameter search space configuration.""" |
34 | 55 |
|
35 | 56 | low: float = Field(..., description="Lower boundary of the search space.") |
36 | 57 | high: float = Field(..., description="Upper boundary of the search space.") |
37 | 58 | step: float | None = Field(None, description="Step size for the search space (if applicable).") |
38 | 59 | log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") |
39 | 60 |
|
| 61 | + @field_validator("step") |
| 62 | + @classmethod |
| 63 | + def validate_step_with_log(cls, v: float | None, info: ValidationInfo) -> float | None: |
| 64 | + """Validate that step is not used when log is True. |
| 65 | +
|
| 66 | + Args: |
| 67 | + v: The step value to validate |
| 68 | + info: Validation info containing other field values |
| 69 | +
|
| 70 | + Returns: |
| 71 | + The validated step value |
| 72 | +
|
| 73 | + Raises: |
| 74 | + ValueError: If step is provided when log is True |
| 75 | + """ |
| 76 | + if info.data.get("log", False) and v is not None: |
| 77 | + msg = "Step cannot be used when log is True. See optuna docs on `suggest_float` (https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_float)." |
| 78 | + raise ValueError(msg) |
| 79 | + return v |
| 80 | + |
| 81 | + def n_possible_values(self) -> int | None: |
| 82 | + """Calculate the number of possible values in the search space. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + The number of possible values or None if search space is continuous. |
| 86 | + """ |
| 87 | + if self.step is None: |
| 88 | + return None |
| 89 | + return int((self.high - self.low) // self.step) + 1 |
| 90 | + |
40 | 91 |
|
41 | 92 | logger = logging.getLogger(__name__) |
42 | 93 |
|
43 | 94 |
|
| 95 | +ParamSpaceT = TypeVar("ParamSpaceT", bound=ParamSpace) |
| 96 | + |
| 97 | + |
44 | 98 | class NodeOptimizer: |
45 | 99 | """Class for optimizing nodes in a computational pipeline. |
46 | 100 |
|
@@ -104,6 +158,9 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1) |
104 | 158 | else: |
105 | 159 | assert_never(sampler) |
106 | 160 |
|
| 161 | + if n_trials and (possible_combinations := self._n_possible_combinations(search_space)): |
| 162 | + n_trials = min(possible_combinations, n_trials) |
| 163 | + |
107 | 164 | study, finished_trials, n_trials = load_or_create_study( |
108 | 165 | study_name=f"{self.node_info.node_type}_{module_name}", |
109 | 166 | storage_dir=context.get_dump_dir(), |
@@ -205,23 +262,44 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic |
205 | 262 | for param_name, param_space in search_space.items(): |
206 | 263 | if isinstance(param_space, list): |
207 | 264 | res[param_name] = trial.suggest_categorical(param_name, choices=param_space) |
208 | | - elif self._is_valid_param_space(param_space, ParamSpaceInt): |
| 265 | + elif self._parse_param_space(param_space, ParamSpaceInt): |
209 | 266 | res[param_name] = trial.suggest_int(param_name, **param_space) |
210 | | - elif self._is_valid_param_space(param_space, ParamSpaceFloat): |
| 267 | + elif self._parse_param_space(param_space, ParamSpaceFloat): |
211 | 268 | res[param_name] = trial.suggest_float(param_name, **param_space) |
212 | 269 | else: |
213 | 270 | msg = f"Unsupported type of param search space: {param_space}" |
214 | 271 | raise TypeError(msg) |
215 | 272 | return res |
216 | 273 |
|
217 | | - def _is_valid_param_space( |
218 | | - self, param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat] |
219 | | - ) -> bool: |
| 274 | + def _n_possible_combinations(self, search_space: dict[str, Any]) -> int | None: |
| 275 | + """Calculate the number of possible combinations in the search space. |
| 276 | +
|
| 277 | + Args: |
| 278 | + search_space: The parameter search space. |
| 279 | +
|
| 280 | + Returns: |
| 281 | + The number of possible combinations or None if search space is continuous. |
| 282 | + """ |
| 283 | + n_combinations = 1 |
| 284 | + for param_space in search_space.values(): |
| 285 | + if isinstance(param_space, list): |
| 286 | + n_combinations *= len(param_space) |
| 287 | + elif param_space_int := self._parse_param_space(param_space, ParamSpaceInt): |
| 288 | + n_combinations *= param_space_int.n_possible_values() |
| 289 | + elif param_space_float := self._parse_param_space(param_space, ParamSpaceFloat): |
| 290 | + n_possible_values = param_space_float.n_possible_values() |
| 291 | + if n_possible_values is None: |
| 292 | + return None |
| 293 | + n_combinations *= n_possible_values |
| 294 | + else: |
| 295 | + assert_never(param_space) |
| 296 | + return n_combinations |
| 297 | + |
| 298 | + def _parse_param_space(self, param_space: dict[str, Any], space_type: type[ParamSpaceT]) -> ParamSpaceT | None: |
220 | 299 | try: |
221 | | - space_type(**param_space) |
222 | | - return True # noqa: TRY300 |
| 300 | + return space_type(**param_space) |
223 | 301 | except ValueError: |
224 | | - return False |
| 302 | + return None |
225 | 303 |
|
226 | 304 | def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: |
227 | 305 | """Creates and returns the path to the module dump directory. |
@@ -305,7 +383,7 @@ def _reformat_search_space(self, module_search_space: dict[str, Any]) -> tuple[d |
305 | 383 | continue |
306 | 384 | if isinstance(param_space, list): |
307 | 385 | res[param_name] = param_space |
308 | | - elif self._is_valid_param_space(param_space, ParamSpaceInt) or self._is_valid_param_space( |
| 386 | + elif self._parse_param_space(param_space, ParamSpaceInt) or self._parse_param_space( |
309 | 387 | param_space, ParamSpaceFloat |
310 | 388 | ): |
311 | 389 | res[param_name] = [param_space["low"], param_space["high"]] |
|
0 commit comments