Skip to content

Commit 8935f09

Browse files
committed
try to fix
1 parent b57c002 commit 8935f09

File tree

1 file changed

+72
-11
lines changed

1 file changed

+72
-11
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import gc
44
import itertools as it
55
import logging
6+
from abc import ABC, abstractmethod
67
from copy import deepcopy
78
from functools import partial
89
from pathlib import Path
9-
from typing import Any
10+
from typing import Any, TypeVar
1011

12+
import numpy as np
1113
import optuna
1214
import torch
1315
from optuna.trial import Trial
@@ -20,27 +22,65 @@
2022
from autointent.nodes.info import NODES_INFO
2123

2224

23-
class ParamSpaceInt(BaseModel):
25+
class ParamSpace(BaseModel, ABC):
26+
"""Base class for parameter search space configuration."""
27+
28+
@abstractmethod
29+
def n_possible_values(self) -> int | None:
30+
"""Calculate the number of possible values in the search space.
31+
32+
Returns:
33+
The number of possible values or None if search space is continuous.
34+
"""
35+
36+
37+
class ParamSpaceInt(ParamSpace):
2438
"""Integer parameter search space configuration."""
2539

2640
low: int = Field(..., description="Lower boundary of the search space.")
2741
high: int = Field(..., description="Upper boundary of the search space.")
2842
step: int = Field(1, description="Step size for the search space.")
2943
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
3044

45+
def n_possible_values(self) -> int:
46+
"""Calculate the number of possible values in the search space.
3147
32-
class ParamSpaceFloat(BaseModel):
48+
Returns:
49+
The number of possible values.
50+
"""
51+
if self.log:
52+
return int(np.logspace(np.log10(self.low), np.log10(self.high), num=self.step))
53+
54+
return (self.high - self.low) // self.step + 1
55+
56+
57+
class ParamSpaceFloat(ParamSpace):
3358
"""Float parameter search space configuration."""
3459

3560
low: float = Field(..., description="Lower boundary of the search space.")
3661
high: float = Field(..., description="Upper boundary of the search space.")
3762
step: float | None = Field(None, description="Step size for the search space (if applicable).")
3863
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
3964

65+
def n_possible_values(self) -> int | None:
66+
"""Calculate the number of possible values in the search space.
67+
68+
Returns:
69+
The number of possible values or None if search space is continuous.
70+
"""
71+
if self.step is None:
72+
return None
73+
if self.log:
74+
return int(np.logspace(np.log10(self.low), np.log10(self.high), num=self.step))
75+
return (self.high - self.low) // self.step + 1
76+
4077

4178
logger = logging.getLogger(__name__)
4279

4380

81+
ParamSpaceT = TypeVar("ParamSpaceT", bound=ParamSpace)
82+
83+
4484
class NodeOptimizer:
4585
"""Class for optimizing nodes in a computational pipeline.
4686
@@ -104,6 +144,8 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
104144
else:
105145
assert_never(sampler)
106146

147+
n_trials = None if n_trials is None else min(self._n_possible_combinations(search_space), n_trials)
148+
107149
study, finished_trials, n_trials = load_or_create_study(
108150
study_name=f"{self.node_info.node_type}_{module_name}",
109151
storage_dir=context.get_dump_dir(),
@@ -205,23 +247,42 @@ def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dic
205247
for param_name, param_space in search_space.items():
206248
if isinstance(param_space, list):
207249
res[param_name] = trial.suggest_categorical(param_name, choices=param_space)
208-
elif self._is_valid_param_space(param_space, ParamSpaceInt):
250+
elif self._parse_param_space(param_space, ParamSpaceInt):
209251
res[param_name] = trial.suggest_int(param_name, **param_space)
210-
elif self._is_valid_param_space(param_space, ParamSpaceFloat):
252+
elif self._parse_param_space(param_space, ParamSpaceFloat):
211253
res[param_name] = trial.suggest_float(param_name, **param_space)
212254
else:
213255
msg = f"Unsupported type of param search space: {param_space}"
214256
raise TypeError(msg)
215257
return res
216258

217-
def _is_valid_param_space(
218-
self, param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat]
219-
) -> bool:
259+
def _n_possible_combinations(self, search_space: dict[str, Any]) -> int:
260+
"""Calculate the number of possible combinations in the search space.
261+
262+
Args:
263+
search_space: The parameter search space.
264+
"""
265+
n_combinations = 1
266+
for param_space in search_space.values():
267+
if isinstance(param_space, list):
268+
n_combinations *= len(param_space)
269+
else:
270+
param_space_int = self._parse_param_space(param_space, ParamSpaceInt)
271+
if param_space_int is not None:
272+
n_combinations *= param_space_int.n_possible_values()
273+
continue
274+
param_space_float = self._parse_param_space(param_space, ParamSpaceFloat)
275+
if param_space_float is not None:
276+
n_combinations *= param_space_float.n_possible_values()
277+
continue
278+
assert_never(param_space)
279+
return n_combinations
280+
281+
def _parse_param_space(self, param_space: dict[str, Any], space_type: type[ParamSpaceT]) -> ParamSpaceT | None:
220282
try:
221-
space_type(**param_space)
222-
return True # noqa: TRY300
283+
return space_type(**param_space)
223284
except ValueError:
224-
return False
285+
return None
225286

226287
def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
227288
"""Creates and returns the path to the module dump directory.

0 commit comments

Comments
 (0)