|
3 | 3 | import gc |
4 | 4 | import itertools as it |
5 | 5 | import logging |
| 6 | +from abc import ABC, abstractmethod |
6 | 7 | from copy import deepcopy |
7 | 8 | from functools import partial |
8 | 9 | from pathlib import Path |
9 | | -from typing import Any |
| 10 | +from typing import Any, TypeVar |
10 | 11 |
|
| 12 | +import numpy as np |
11 | 13 | import optuna |
12 | 14 | import torch |
13 | 15 | from optuna.trial import Trial |
|
20 | 22 | from autointent.nodes.info import NODES_INFO |
21 | 23 |
|
22 | 24 |
|
23 | | -class ParamSpaceInt(BaseModel): |
class ParamSpace(BaseModel, ABC):
    """Abstract base for parameter search space configurations.

    Concrete subclasses describe one kind of search space and must report
    how many distinct values that space contains.
    """

    @abstractmethod
    def n_possible_values(self) -> int | None:
        """Report how many distinct values the search space contains.

        Returns:
            The count of possible values, or None when the search space is
            continuous and therefore cannot be enumerated.
        """
| 35 | + |
| 36 | + |
class ParamSpaceInt(ParamSpace):
    """Integer parameter search space configuration."""

    low: int = Field(..., description="Lower boundary of the search space.")
    high: int = Field(..., description="Upper boundary of the search space.")
    step: int = Field(1, description="Step size for the search space.")
    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")

    def n_possible_values(self) -> int:
        """Calculate the number of possible values in the search space.

        Returns:
            The number of distinct integers that can be drawn from the space.
        """
        if self.log:
            # A logarithmic scale only changes how densely the range is
            # sampled, not which integers are candidates: every integer in
            # [low, high] remains reachable. Optuna's ``suggest_int`` rejects
            # ``step != 1`` together with ``log=True``, so the step plays no
            # role here.
            return self.high - self.low + 1

        # Linear grid: low, low + step, ... up to the last point <= high.
        return (self.high - self.low) // self.step + 1
| 55 | + |
| 56 | + |
class ParamSpaceFloat(ParamSpace):
    """Float parameter search space configuration."""

    low: float = Field(..., description="Lower boundary of the search space.")
    high: float = Field(..., description="Upper boundary of the search space.")
    step: float | None = Field(None, description="Step size for the search space (if applicable).")
    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")

    def n_possible_values(self) -> int | None:
        """Calculate the number of possible values in the search space.

        Returns:
            The number of grid points when a step is set on a linear scale,
            or None if the search space cannot be enumerated (no step, or a
            logarithmic scale).
        """
        # Local import keeps this block self-contained; ``decimal`` is stdlib.
        from decimal import Decimal

        # Without a step the space is continuous. A step combined with a
        # logarithmic scale is not a discretisation Optuna's ``suggest_float``
        # accepts, so there is no grid to count in that case either.
        if self.step is None or self.log:
            return None

        # Count grid points with decimal arithmetic: binary float floor
        # division under-counts on common inputs (e.g. ``1.0 // 0.1 == 9.0``)
        # and would also return a float instead of the annotated int.
        span = Decimal(str(self.high)) - Decimal(str(self.low))
        return int(span // Decimal(str(self.step))) + 1
| 76 | + |
40 | 77 |
|
41 | 78 | logger = logging.getLogger(__name__) |
42 | 79 |
|
43 | 80 |
|
| 81 | +ParamSpaceT = TypeVar("ParamSpaceT", bound=ParamSpace) |
| 82 | + |
| 83 | + |
44 | 84 | class NodeOptimizer: |
45 | 85 | """Class for optimizing nodes in a computational pipeline. |
46 | 86 |
|
@@ -104,6 +144,8 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1) |
104 | 144 | else: |
105 | 145 | assert_never(sampler) |
106 | 146 |
|
| 147 | + n_trials = None if n_trials is None else min(self._n_possible_combinations(search_space), n_trials) |
| 148 | + |
107 | 149 | study, finished_trials, n_trials = load_or_create_study( |
108 | 150 | study_name=f"{self.node_info.node_type}_{module_name}", |
109 | 151 | storage_dir=context.get_dump_dir(), |
@@ -205,23 +247,42 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic |
205 | 247 | for param_name, param_space in search_space.items(): |
206 | 248 | if isinstance(param_space, list): |
207 | 249 | res[param_name] = trial.suggest_categorical(param_name, choices=param_space) |
208 | | - elif self._is_valid_param_space(param_space, ParamSpaceInt): |
| 250 | + elif self._parse_param_space(param_space, ParamSpaceInt): |
209 | 251 | res[param_name] = trial.suggest_int(param_name, **param_space) |
210 | | - elif self._is_valid_param_space(param_space, ParamSpaceFloat): |
| 252 | + elif self._parse_param_space(param_space, ParamSpaceFloat): |
211 | 253 | res[param_name] = trial.suggest_float(param_name, **param_space) |
212 | 254 | else: |
213 | 255 | msg = f"Unsupported type of param search space: {param_space}" |
214 | 256 | raise TypeError(msg) |
215 | 257 | return res |
216 | 258 |
|
217 | | - def _is_valid_param_space( |
218 | | - self, param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat] |
219 | | - ) -> bool: |
| 259 | + def _n_possible_combinations(self, search_space: dict[str, Any]) -> int: |
| 260 | + """Calculate the number of possible combinations in the search space. |
| 261 | +
|
| 262 | + Args: |
| 263 | + search_space: The parameter search space. |
| 264 | + """ |
| 265 | + n_combinations = 1 |
| 266 | + for param_space in search_space.values(): |
| 267 | + if isinstance(param_space, list): |
| 268 | + n_combinations *= len(param_space) |
| 269 | + else: |
| 270 | + param_space_int = self._parse_param_space(param_space, ParamSpaceInt) |
| 271 | + if param_space_int is not None: |
| 272 | + n_combinations *= param_space_int.n_possible_values() |
| 273 | + continue |
| 274 | + param_space_float = self._parse_param_space(param_space, ParamSpaceFloat) |
| 275 | + if param_space_float is not None: |
| 276 | + n_combinations *= param_space_float.n_possible_values() |
| 277 | + continue |
| 278 | + assert_never(param_space) |
| 279 | + return n_combinations |
| 280 | + |
| 281 | + def _parse_param_space(self, param_space: dict[str, Any], space_type: type[ParamSpaceT]) -> ParamSpaceT | None: |
220 | 282 | try: |
221 | | - space_type(**param_space) |
222 | | - return True # noqa: TRY300 |
| 283 | + return space_type(**param_space) |
223 | 284 | except ValueError: |
224 | | - return False |
| 285 | + return None |
225 | 286 |
|
226 | 287 | def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: |
227 | 288 | """Creates and returns the path to the module dump directory. |
|
0 commit comments