Skip to content

Commit b0e1e0c

Browse files
authored
Refactor/search space validation (#153)
* add search space validation to `NodeOptimizer` constructor * add validation to decision modules' constructors * add validaiton to embedding modules' constructors * add validaiton to scoring modules' constructors * add optuna notation handling in search space validation * fix typing * disable previous validation * bug fix * update tests * remove previous validation * update tests * update schema * bug fix * update docs * edit presets
1 parent a5b777e commit b0e1e0c

30 files changed

+303
-1821
lines changed

autointent/_optimization_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
from typing import Any
2+
13
from pydantic import BaseModel, PositiveInt
24

35
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
46
from .custom_types import SamplerType
5-
from .nodes.schemes import OptimizationSearchSpaceConfig
67

78

89
class OptimizationConfig(BaseModel):
910
"""Configuration for the optimization process."""
1011

1112
data_config: DataConfig = DataConfig()
12-
search_space: OptimizationSearchSpaceConfig
13+
search_space: list[dict[str, Any]]
1314
logging_config: LoggingConfig = LoggingConfig()
1415
embedder_config: EmbedderConfig = EmbedderConfig()
1516
cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

autointent/_pipeline/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
122122
optimization_config = OptimizationConfig(**dict_params)
123123

124124
pipeline = cls(
125-
[NodeOptimizer(**node.model_dump()) for node in optimization_config.search_space],
125+
[NodeOptimizer(**node) for node in optimization_config.search_space],
126126
optimization_config.sampler,
127127
optimization_config.seed,
128128
)

autointent/_presets/heavy.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ search_space:
1414
k:
1515
low: 1
1616
high: 20
17-
step: 1
1817
n_trials: 10
1918
- module_name: description
2019
temperature:

autointent/_presets/heavy_moderate.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ search_space:
1313
k:
1414
low: 1
1515
high: 20
16-
step: 1
1716
n_trials: 10
1817
- module_name: description
1918
temperature:

autointent/_presets/light.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ search_space:
1313
k:
1414
low: 1
1515
high: 20
16-
step: 1
1716
n_trials: 10
1817
- node_type: decision
1918
target_metric: decision_accuracy
@@ -22,6 +21,6 @@ search_space:
2221
thresh:
2322
low: 0.1
2423
high: 0.9
25-
step: 0.1
24+
n_trials: 10
2625
- module_name: argmax
2726
sampler: tpe

autointent/_presets/light_extra.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ search_space:
2121
thresh:
2222
low: 0.1
2323
high: 0.9
24-
n_trials: 10
24+
n_trials: 10
2525
- module_name: argmax
2626
sampler: random

autointent/modules/decision/_adaptive.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def __init__(self, search_space: list[FloatFromZeroToOne] | None = None) -> None
6767
"""
6868
self.search_space = search_space if search_space is not None else default_search_space
6969

70+
if any(val < 0 or val > 1 for val in self.search_space):
71+
msg = "Unsupported items in `search_space` arg of `AdaptiveDecision` module"
72+
raise ValueError(msg)
73+
7074
@classmethod
7175
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "AdaptiveDecision":
7276
"""

autointent/modules/decision/_jinoos.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def __init__(
6464
"""
6565
self.search_space = np.array(search_space) if search_space is not None else default_search_space
6666

67+
if any(val < 0 or val > 1 for val in self.search_space):
68+
msg = "Items pf `search_space` of `AdaptiveDecision` module must be a floats from zero to one"
69+
raise ValueError(msg)
70+
6771
@classmethod
6872
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "JinoosDecision":
6973
"""

autointent/modules/decision/_threshold.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,16 @@ def __init__(
8282
8383
:param thresh: Threshold for the scores, shape (n_classes,) or float
8484
"""
85+
val_error = False
8586
self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)
87+
if isinstance(thresh, float):
88+
val_error = val_error or thresh < 0 or thresh > 1
89+
else:
90+
val_error = val_error or any(val < 0 or val > 1 for val in thresh)
91+
92+
if val_error:
93+
msg = "`thresh` arg of `ThresholdDecision` must contain a float from zero to one (or list of floats)."
94+
raise ValueError(msg)
8695

8796
@classmethod
8897
def from_context(

autointent/modules/decision/_tunable.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tunable predictor module."""
22

3-
from typing import Any, Literal
3+
from typing import Any, Literal, get_args
44

55
import numpy as np
66
import numpy.typing as npt
@@ -96,6 +96,14 @@ def __init__(
9696
self.seed = seed
9797
self.tags = tags
9898

99+
if self.n_optuna_trials < 0 or not isinstance(self.n_optuna_trials, int):
100+
msg = "Unsupported value for `n_optuna_trial` of `TunableDecision` module"
101+
raise ValueError(msg)
102+
103+
if self.target_metric not in get_args(MetricType):
104+
msg = "Unsupported value for `target_metric` of `TunableDecision` module"
105+
raise TypeError(msg)
106+
99107
@classmethod
100108
def from_context(
101109
cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320

0 commit comments

Comments
 (0)