Skip to content

Commit b11f845

Browse files
voorhs and Samoed authored
Fix/n trials issue (#196)
* try to fix
* fix typing errors
* bug fix
* Update autointent/nodes/_node_optimizer.py

Co-authored-by: Roman Solomatin <samoed.roman@gmail.com>

--------

Co-authored-by: Roman Solomatin <samoed.roman@gmail.com>
1 parent b57c002 commit b11f845

File tree

1 file changed

+91
-13
lines changed

1 file changed

+91
-13
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 91 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import gc
44
import itertools as it
55
import logging
6+
from abc import ABC, abstractmethod
67
from copy import deepcopy
78
from functools import partial
89
from pathlib import Path
9-
from typing import Any
10+
from typing import Any, TypeVar
1011

1112
import optuna
1213
import torch
1314
from optuna.trial import Trial
14-
from pydantic import BaseModel, Field
15+
from pydantic import BaseModel, Field, ValidationInfo, field_validator
1516
from typing_extensions import assert_never
1617

1718
from autointent import Dataset
@@ -20,27 +21,80 @@
2021
from autointent.nodes.info import NODES_INFO
2122

2223

23-
class ParamSpaceInt(BaseModel):
24+
class ParamSpace(BaseModel, ABC):
    """Abstract pydantic model shared by all parameter search space configurations."""

    @abstractmethod
    def n_possible_values(self) -> int | None:
        """Count the distinct values this search space can take.

        Returns:
            The number of distinct values, or None when the search space is continuous.
        """
34+
35+
36+
class ParamSpaceInt(ParamSpace):
    """Configuration of a search space over integers."""

    low: int = Field(..., description="Lower boundary of the search space.")
    high: int = Field(..., description="Upper boundary of the search space.")
    step: int = Field(1, description="Step size for the search space.")
    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")

    def n_possible_values(self) -> int:
        """Count the integers reachable from ``low`` in increments of ``step`` without exceeding ``high``.

        Returns:
            The number of possible values.
        """
        # Whole steps that fit between the bounds; the extra one accounts for ``low`` itself.
        whole_steps = (self.high - self.low) // self.step
        return whole_steps + 1
3151

32-
class ParamSpaceFloat(BaseModel):
52+
53+
class ParamSpaceFloat(ParamSpace):
    """Float parameter search space configuration."""

    low: float = Field(..., description="Lower boundary of the search space.")
    high: float = Field(..., description="Upper boundary of the search space.")
    step: float | None = Field(None, description="Step size for the search space (if applicable).")
    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")

    @field_validator("log")
    @classmethod
    def validate_step_with_log(cls, v: bool, info: ValidationInfo) -> bool:  # noqa: FBT001
        """Validate that step is not used when log is True.

        The check is attached to ``log`` rather than ``step`` on purpose: pydantic
        validates fields in declaration order and ``info.data`` only contains the
        fields validated so far. ``log`` is declared after ``step``, so a validator
        on ``step`` would never see ``log`` and the check would silently never fire.

        Args:
            v: The ``log`` value to validate
            info: Validation info containing the previously validated fields

        Returns:
            The validated ``log`` value

        Raises:
            ValueError: If step is provided when log is True
        """
        # ``step`` is absent from ``info.data`` if it failed its own validation; ``.get`` covers that.
        if v and info.data.get("step") is not None:
            msg = "Step cannot be used when log is True. See optuna docs on `suggest_float` (https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_float)."
            raise ValueError(msg)
        return v

    def n_possible_values(self) -> int | None:
        """Calculate the number of possible values in the search space.

        Returns:
            The number of possible values or None if search space is continuous.
        """
        if self.step is None:
            return None

        # Local import keeps this block self-contained.
        from decimal import Decimal

        # Binary floats make floor division unreliable here: (1.0 - 0.0) // 0.1 == 9.0,
        # which undercounts the grid 0.0, 0.1, ..., 1.0 by one. Going through the decimal
        # string representation counts the steps the way they were written.
        span = Decimal(str(self.high)) - Decimal(str(self.low))
        return int(span // Decimal(str(self.step))) + 1
90+
4091

4192
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
4293

4394

95+
# Type variable bound to ParamSpace, used to annotate `_parse_param_space`:
# the return type is the same ParamSpace subclass passed in as `space_type` (or None).
ParamSpaceT = TypeVar("ParamSpaceT", bound=ParamSpace)
96+
97+
4498
class NodeOptimizer:
4599
"""Class for optimizing nodes in a computational pipeline.
46100
@@ -104,6 +158,9 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
104158
else:
105159
assert_never(sampler)
106160

161+
if n_trials and (possible_combinations := self._n_possible_combinations(search_space)):
162+
n_trials = min(possible_combinations, n_trials)
163+
107164
study, finished_trials, n_trials = load_or_create_study(
108165
study_name=f"{self.node_info.node_type}_{module_name}",
109166
storage_dir=context.get_dump_dir(),
@@ -205,23 +262,44 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic
205262
for param_name, param_space in search_space.items():
206263
if isinstance(param_space, list):
207264
res[param_name] = trial.suggest_categorical(param_name, choices=param_space)
208-
elif self._is_valid_param_space(param_space, ParamSpaceInt):
265+
elif self._parse_param_space(param_space, ParamSpaceInt):
209266
res[param_name] = trial.suggest_int(param_name, **param_space)
210-
elif self._is_valid_param_space(param_space, ParamSpaceFloat):
267+
elif self._parse_param_space(param_space, ParamSpaceFloat):
211268
res[param_name] = trial.suggest_float(param_name, **param_space)
212269
else:
213270
msg = f"Unsupported type of param search space: {param_space}"
214271
raise TypeError(msg)
215272
return res
216273

217-
def _is_valid_param_space(
218-
self, param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat]
219-
) -> bool:
274+
def _n_possible_combinations(self, search_space: dict[str, Any]) -> int | None:
275+
"""Calculate the number of possible combinations in the search space.
276+
277+
Args:
278+
search_space: The parameter search space.
279+
280+
Returns:
281+
The number of possible combinations or None if search space is continuous.
282+
"""
283+
n_combinations = 1
284+
for param_space in search_space.values():
285+
if isinstance(param_space, list):
286+
n_combinations *= len(param_space)
287+
elif param_space_int := self._parse_param_space(param_space, ParamSpaceInt):
288+
n_combinations *= param_space_int.n_possible_values()
289+
elif param_space_float := self._parse_param_space(param_space, ParamSpaceFloat):
290+
n_possible_values = param_space_float.n_possible_values()
291+
if n_possible_values is None:
292+
return None
293+
n_combinations *= n_possible_values
294+
else:
295+
assert_never(param_space)
296+
return n_combinations
297+
298+
def _parse_param_space(self, param_space: dict[str, Any], space_type: type[ParamSpaceT]) -> ParamSpaceT | None:
220299
try:
221-
space_type(**param_space)
222-
return True # noqa: TRY300
300+
return space_type(**param_space)
223301
except ValueError:
224-
return False
302+
return None
225303

226304
def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
227305
"""Creates and returns the path to the module dump directory.
@@ -305,7 +383,7 @@ def _reformat_search_space(self, module_search_space: dict[str, Any]) -> tuple[d
305383
continue
306384
if isinstance(param_space, list):
307385
res[param_name] = param_space
308-
elif self._is_valid_param_space(param_space, ParamSpaceInt) or self._is_valid_param_space(
386+
elif self._parse_param_space(param_space, ParamSpaceInt) or self._parse_param_space(
309387
param_space, ParamSpaceFloat
310388
):
311389
res[param_name] = [param_space["low"], param_space["high"]]

0 commit comments

Comments
 (0)