Skip to content

Commit 9900d43

Browse files
voorhs and Samoed authored
Feat/optuna (#128)
* define interface * basic ho iterator * move obtaining data for train from node optimizer to modules themselves * stage progress * implement cv iterator * minor bug fix * implement cv iterator for decision node * move cv iteration to base module definition * implement cv iterator for embedding node * add training to `score_ho` of each node * properly define base module * fix codestyle * remove regexp node * remove regexp validator * fix typing problems (except `DataHandler._split_cv`) * add ignore oos decorator * fix codestyle * fix typing * add oos handling to cv iterator * remove `DataHandler.dump()` * minor bug fix * implement splitting to cv folds * fix codestyle * remove regex tests * bug fix * bug fix * update tests * fix typing * big fix * basic test on cv folding * add tests for metrics to ignore oos samples * add tests for cv iterator * fix codestyle * minor bug fix * fix codestyle * add test for cv * bug fix * implement cv iterator for description scorer * refactor cv iterator for description node * fix typing * add cache cleaning before refitting * bug fix * implement refitting the whole pipeline with all train data * fix typing * bug fix * fix typing * respond to samoed * create `ValidationType` in `autointent.custom_types` * fix docstring * properly expose `n_folds` argument * implement `_fit_bayes` * add typing to param spaces * minor bug fix * minor bug fix * fix codestyle * add tuning selection to pipeline * add test on bayes * disable search space validation for now * fix codestyle * remove `ParamSpaceCat` (it's redundant) * move to optuna entirely * refactor yaml format a little bit * add test for random sampler * rename some variables * add config validation for optuna (#132) * add config validation * add validation for union types * remove debug code * remove comment * run tests on pr for all branches * fix mlknn * fix type validation * return CI config back to normal * fix default value for step in `ParamSpaceFloat` * update schema * update 
callback test * change CI config * update search space configs for testing * enable validation back * remove TunableDecision from search spaces * upd schema --------- Co-authored-by: Roman Solomatin <[email protected]> Co-authored-by: Roman Solomatin <[email protected]>
1 parent b134971 commit 9900d43

File tree

26 files changed

+790
-193
lines changed

26 files changed

+790
-193
lines changed

.github/workflows/test-inference.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ on:
55
branches:
66
- dev
77
pull_request:
8-
branches:
9-
- dev
108

119
jobs:
1210
test:

.github/workflows/test-nodes.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ on:
55
branches:
66
- dev
77
pull_request:
8-
branches:
9-
- dev
108

119
jobs:
1210
test:

.github/workflows/test-optimization.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ on:
55
branches:
66
- dev
77
pull_request:
8-
branches:
9-
- dev
108

119
jobs:
1210
test:

.github/workflows/unit-tests.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ on:
55
branches:
66
- dev
77
pull_request:
8-
branches:
9-
- dev
108

119
jobs:
1210
test:

autointent/_pipeline/_pipeline.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from autointent import Context, Dataset
1212
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
13-
from autointent.custom_types import ListOfGenericLabels, NodeType, ValidationScheme
13+
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
1616
from autointent.nodes.schemes import OptimizationConfig
@@ -87,7 +87,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
8787
"""
8888
return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)
8989

90-
def _fit(self, context: Context) -> None:
90+
def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
9191
"""
9292
Optimize the pipeline.
9393
@@ -102,7 +102,7 @@ def _fit(self, context: Context) -> None:
102102
for node_type in NodeType:
103103
node_optimizer = self.nodes.get(node_type, None)
104104
if node_optimizer is not None:
105-
node_optimizer.fit(context) # type: ignore[union-attr]
105+
node_optimizer.fit(context, sampler) # type: ignore[union-attr]
106106
if not context.vector_index_config.save_db:
107107
self._logger.info("removing vector database from file system...")
108108
# TODO clear cache from appdirs
@@ -117,7 +117,12 @@ def _is_inference(self) -> bool:
117117
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
118118

119119
def fit(
120-
self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3, refit_after: bool = False
120+
self,
121+
dataset: Dataset,
122+
scheme: ValidationScheme = "ho",
123+
n_folds: int = 3,
124+
refit_after: bool = False,
125+
sampler: SamplerType = "brute",
121126
) -> Context:
122127
"""
123128
Optimize the pipeline from dataset.
@@ -135,7 +140,7 @@ def fit(
135140
context.configure_vector_index(self.vector_index_config)
136141

137142
self.validate_modules(dataset)
138-
self._fit(context)
143+
self._fit(context, sampler)
139144

140145
if context.is_ram_to_clear():
141146
nodes_configs = context.optimization_info.get_inference_nodes_config()

autointent/configs/_optimization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from pathlib import Path
44

5-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, PositiveInt
66

7-
from autointent.custom_types import ValidationScheme
7+
from autointent.custom_types import SamplerType, ValidationScheme
88

99
from ._name import get_run_name
1010

@@ -16,7 +16,7 @@ class DataConfig(BaseModel):
1616
"""Path to the training data. Can be local path or HF repo."""
1717
scheme: ValidationScheme
1818
"""Hold-out or cross-validation."""
19-
n_folds: int = 3
19+
n_folds: PositiveInt = 3
2020
"""Number of folds in cross-validation."""
2121

2222

@@ -25,6 +25,7 @@ class TaskConfig(BaseModel):
2525

2626
search_space_path: Path | None = None
2727
"""Path to the search space configuration file. If None, the default search space will be used"""
28+
sampler: SamplerType = "brute"
2829

2930

3031
class LoggingConfig(BaseModel):

autointent/custom_types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"""
66

77
from enum import Enum
8-
from typing import Literal, TypeAlias
8+
from typing import Annotated, Literal, TypeAlias
9+
10+
from annotated_types import Interval
911

1012

1113
class LogLevel(Enum):
@@ -71,4 +73,9 @@ class Split:
7173
INTENTS = "intents"
7274

7375

76+
SamplerType = Literal["brute", "tpe", "random"]
7477
ValidationScheme = Literal["ho", "cv"]
78+
79+
80+
FloatFromZeroToOne = Annotated[float, Interval(ge=0, le=1)]
81+
"""Float value between 0 and 1, inclusive."""

autointent/modules/decision/_adaptive.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy.typing as npt
88

99
from autointent import Context
10-
from autointent.custom_types import ListOfGenericLabels, ListOfLabelsWithOOS, MultiLabel
10+
from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabelsWithOOS, MultiLabel
1111
from autointent.exceptions import MismatchNumClassesError
1212
from autointent.metrics import decision_f1
1313
from autointent.modules.abc import DecisionModule
@@ -58,7 +58,7 @@ class AdaptiveDecision(DecisionModule):
5858
supports_oos = False
5959
name = "adaptive"
6060

61-
def __init__(self, search_space: list[float] | None = None) -> None:
61+
def __init__(self, search_space: list[FloatFromZeroToOne] | None = None) -> None:
6262
"""
6363
Initialize the AdaptiveDecision.
6464
@@ -68,7 +68,7 @@ def __init__(self, search_space: list[float] | None = None) -> None:
6868
self.search_space = search_space if search_space is not None else default_search_space
6969

7070
@classmethod
71-
def from_context(cls, context: Context, search_space: list[float] | None = None) -> "AdaptiveDecision":
71+
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "AdaptiveDecision":
7272
"""
7373
Create an AdaptiveDecision instance using a Context object.
7474

autointent/modules/decision/_jinoos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy.typing as npt
77

88
from autointent import Context
9-
from autointent.custom_types import ListOfGenericLabels
9+
from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels
1010
from autointent.exceptions import MismatchNumClassesError
1111
from autointent.modules.abc import DecisionModule
1212
from autointent.schemas import Tag
@@ -55,7 +55,7 @@ class JinoosDecision(DecisionModule):
5555

5656
def __init__(
5757
self,
58-
search_space: list[float] | None = None,
58+
search_space: list[FloatFromZeroToOne] | None = None,
5959
) -> None:
6060
"""
6161
Initialize Jinoos predictor.
@@ -65,7 +65,7 @@ def __init__(
6565
self.search_space = np.array(search_space) if search_space is not None else default_search_space
6666

6767
@classmethod
68-
def from_context(cls, context: Context, search_space: list[float] | None = None) -> "JinoosDecision":
68+
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "JinoosDecision":
6969
"""
7070
Initialize from context.
7171

autointent/modules/decision/_threshold.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy.typing as npt
88

99
from autointent import Context
10-
from autointent.custom_types import ListOfGenericLabels, MultiLabel
10+
from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, MultiLabel
1111
from autointent.exceptions import MismatchNumClassesError
1212
from autointent.modules.abc import DecisionModule
1313
from autointent.schemas import Tag
@@ -75,7 +75,7 @@ class ThresholdDecision(DecisionModule):
7575

7676
def __init__(
7777
self,
78-
thresh: float | list[float],
78+
thresh: FloatFromZeroToOne | list[FloatFromZeroToOne],
7979
) -> None:
8080
"""
8181
Initialize threshold predictor.
@@ -85,7 +85,9 @@ def __init__(
8585
self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)
8686

8787
@classmethod
88-
def from_context(cls, context: Context, thresh: float | list[float] = 0.5) -> "ThresholdDecision":
88+
def from_context(
89+
cls, context: Context, thresh: FloatFromZeroToOne | list[FloatFromZeroToOne] = 0.5
90+
) -> "ThresholdDecision":
8991
"""
9092
Initialize from context.
9193

0 commit comments

Comments
 (0)