
Commit 54b8fad

Authored by Samoed, github-actions[bot], and voorhs
add full config (#144)
* add full config
* Update optimizer_config.schema.json
* add missing config
* Update optimizer_config.schema.json
* update config

  # Conflicts:
  # docs/optimizer_config.schema.json

* Update optimizer_config.schema.json
* fix
* try fix
* try fix
* attempt (#150)

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Алексеев Илья <[email protected]>
1 parent d2ac6e1 commit 54b8fad

19 files changed: +1974 -87 lines changed

autointent/_callbacks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 
 REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
 
-REPORTERS_NAMES = list(REPORTERS.keys())
+REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())]  # type: ignore[valid-type]
 
 
 def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
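
The change above swaps the list of reporter names for a Literal type built from the registry keys, which lets pydantic reject unknown reporters declaratively instead of via the custom field_validator removed from LoggingConfig further down. A minimal standalone sketch of the pattern, using a toy registry and model that are illustrative rather than autointent's actual code:

from typing import Literal

from pydantic import BaseModel, ValidationError

# Toy registry standing in for REPORTERS; the names here are assumptions.
REGISTRY = {"wandb": object, "tensorboard": object}
ReporterName = Literal[tuple(REGISTRY.keys())]  # type: ignore[valid-type]


class ToyLoggingConfig(BaseModel):
    # pydantic checks every entry against the Literal members at construction time
    report_to: list[ReporterName] | None = None


ToyLoggingConfig(report_to=["wandb"])  # accepted
try:
    ToyLoggingConfig(report_to=["mlflow"])  # rejected: not a registry key
except ValidationError as err:
    print(err)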

autointent/_pipeline/_pipeline.py

Lines changed: 40 additions & 6 deletions
@@ -3,7 +3,7 @@
 import json
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, get_args
 
 import numpy as np
 import yaml
@@ -13,7 +13,7 @@
 from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
 from autointent.metrics import DECISION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
-from autointent.nodes.schemes import OptimizationConfig
+from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig
 from autointent.utils import load_default_search_space, load_search_space
 
 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
@@ -28,17 +28,24 @@ class Pipeline:
     def __init__(
         self,
         nodes: list[NodeOptimizer] | list[InferenceNode],
+        sampler: SamplerType = "brute",
         seed: int = 42,
     ) -> None:
         """
         Initialize the pipeline optimizer.
 
         :param nodes: list of nodes
+        :param sampler: sampler type
         :param seed: random seed
         """
         self._logger = logging.getLogger(__name__)
         self.nodes = {node.node_type: node for node in nodes}
         self.seed = seed
+        if sampler not in get_args(SamplerType):
+            msg = f"Sampler should be one of {get_args(SamplerType)}"
+            raise ValueError(msg)
+
+        self.sampler = sampler
 
         if isinstance(nodes[0], NodeOptimizer):
             self.logging_config = LoggingConfig(dump_dir=None)
@@ -74,10 +81,34 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
         """
         if isinstance(search_space, Path | str):
             search_space = load_search_space(search_space)
-        validated_search_space = OptimizationConfig(search_space).model_dump()  # type: ignore[arg-type]
+        validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump()  # type: ignore[arg-type]
         nodes = [NodeOptimizer(**node) for node in validated_search_space]
         return cls(nodes=nodes, seed=seed)
 
+    @classmethod
+    def from_optimization_config(cls, config: dict[str, Any] | Path | str) -> "Pipeline":
+        """
+        Create pipeline optimizer from optimization config.
+
+        :param config: Optimization config
+        :return:
+        """
+        if isinstance(config, Path | str):
+            with Path(config).open() as file:
+                loaded_config = yaml.safe_load(file)
+        else:
+            loaded_config = config
+        optimization_config = OptimizationConfig(**loaded_config)
+        pipeline = cls(
+            [NodeOptimizer(**node.model_dump()) for node in optimization_config.task_config.search_space],
+            optimization_config.task_config.sampler,
+            optimization_config.seed,
+        )
+        pipeline.set_config(optimization_config.logging_config)
+        pipeline.set_config(optimization_config.vector_index_config)
+        pipeline.set_config(optimization_config.data_config)
+        return pipeline
+
     @classmethod
     def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         """
@@ -90,7 +121,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         """
         return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)
 
-    def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
+    def _fit(self, context: Context, sampler: SamplerType) -> None:
         """
         Optimize the pipeline.
 
@@ -99,7 +130,7 @@ def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         self.context = context
         self._logger.info("starting pipeline optimization...")
         self.context.callback_handler.start_run(
-            run_name=self.context.logging_config.run_name,
+            run_name=self.context.logging_config.get_run_name(),
             dirpath=self.context.logging_config.dirpath,
         )
         for node_type in NodeType:
@@ -123,7 +154,7 @@ def fit(
         self,
         dataset: Dataset,
         refit_after: bool = False,
-        sampler: SamplerType = "brute",
+        sampler: SamplerType | None = None,
     ) -> Context:
         """
         Optimize the pipeline from dataset.
@@ -148,6 +179,9 @@ def fit(
                 "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
             )
 
+        if sampler is None:
+            sampler = self.sampler or "brute"
+
         self._fit(context, sampler)
 
         if context.is_ram_to_clear():
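
For context on the new entry point, a hedged sketch of how Pipeline.from_optimization_config might be called once this commit is in place. The top-level import path and the reuse of the default search space are assumptions; the config keys mirror the OptimizationConfig fields added in this diff:

from autointent import Pipeline  # assumed public import path
from autointent.utils import load_default_search_space

# Dict form; a path to a YAML file with the same structure is also accepted
# and is read with yaml.safe_load.
config = {
    "seed": 42,
    "task_config": {
        # reuse the default single-label search space instead of hand-writing node specs
        "search_space": load_default_search_space(False),
        "sampler": "brute",
    },
    "logging_config": {"run_name": "full-config-demo", "dump_modules": False},
    "data_config": {"scheme": "ho", "validation_size": 0.2},
    "vector_index_config": {"save_db": False},
}

pipeline = Pipeline.from_optimization_config(config)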

autointent/configs/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -4,7 +4,6 @@
 from ._optimization import (
     DataConfig,
     LoggingConfig,
-    TaskConfig,
     VectorIndexConfig,
 )
 
@@ -13,6 +12,5 @@
     "InferenceNodeConfig",
     "InferenceNodeConfig",
     "LoggingConfig",
-    "TaskConfig",
     "VectorIndexConfig",
 ]

autointent/configs/_optimization.py

Lines changed: 32 additions & 34 deletions
@@ -2,78 +2,76 @@
 
 from pathlib import Path
 
-from pydantic import BaseModel, Field, PositiveInt, field_validator
+from pydantic import BaseModel, Field, PositiveInt
 
 from autointent._callbacks import REPORTERS_NAMES
-from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
+from autointent.custom_types import FloatFromZeroToOne, ValidationScheme
 
 from ._name import get_run_name
 
 
 class DataConfig(BaseModel):
     """Configuration for the data used in the optimization process."""
 
-    scheme: ValidationScheme = "ho"
+    scheme: ValidationScheme = Field("ho", description="Validation scheme to use.")
     """Hold-out or cross-validation."""
-    n_folds: PositiveInt = 3
+    n_folds: PositiveInt = Field(3, description="Number of folds in cross-validation.")
     """Number of folds in cross-validation."""
-    validation_size: FloatFromZeroToOne = 0.2
+    validation_size: FloatFromZeroToOne = Field(
+        0.2,
+        description=(
+            "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."
+        ),
+    )
     """Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
-    separation_ratio: FloatFromZeroToOne | None = 0.5
+    separation_ratio: FloatFromZeroToOne | None = Field(
+        0.5, description="Set to float to prevent data leak between scoring and decision nodes."
+    )
    """Set to float to prevent data leak between scoring and decision nodes."""
 
 
-class TaskConfig(BaseModel):
-    """Configuration for the task to optimize."""
-
-    search_space_path: Path | None = None
-    """Path to the search space configuration file. If None, the default search space will be used"""
-    sampler: SamplerType = "brute"
-
-
 class LoggingConfig(BaseModel):
     """Configuration for the logging."""
 
-    project_dir: Path = Field(default_factory=lambda: Path.cwd() / "runs")
+    _dirpath: Path | None = None
+    _dump_dir: Path | None = None
+
+    project_dir: Path | str | None = Field(None, description="Path to the directory with different runs.")
     """Path to the directory with different runs."""
-    run_name: str = Field(default_factory=get_run_name)
+    run_name: str | None = Field(None, description="Name of the run. If None, a random name will be generated.")
     """Name of the run. If None, a random name will be generated"""
-    dump_modules: bool = False
+    dump_modules: bool = Field(False, description="Whether to dump the modules or not")
     """Whether to dump the modules or not"""
-    clear_ram: bool = False
+    clear_ram: bool = Field(False, description="Whether to clear the RAM after dumping the modules")
    """Whether to clear the RAM after dumping the modules"""
-    report_to: list[str] | None = None
+    report_to: list[REPORTERS_NAMES] | None = Field(  # type: ignore[valid-type]
+        None, description="List of callbacks to report to. If None, no callbacks will be used"
+    )
    """List of callbacks to report to. If None, no callbacks will be used"""
 
     @property
     def dirpath(self) -> Path:
         """Path to the directory where the logs will be saved."""
-        if not hasattr(self, "_dirpath"):
-            self._dirpath = self.project_dir / self.run_name
+        if self._dirpath is None:
+            project_dir = Path.cwd() / "runs" if self.project_dir is None else Path(self.project_dir)
+            self._dirpath = project_dir / self.get_run_name()
         return self._dirpath
 
     @property
     def dump_dir(self) -> Path:
         """Path to the directory where the modules will be dumped."""
-        if not hasattr(self, "_dump_dir"):
+        if self._dump_dir is None:
             self._dump_dir = self.dirpath / "modules_dumps"
         return self._dump_dir
 
-    @field_validator("report_to")
-    @classmethod
-    def validate_report_to(cls, v: list[str] | None) -> list[str] | None:
-        """Validate the report_to field."""
-        if v is None:
-            return None
-        for reporter in v:
-            if reporter not in REPORTERS_NAMES:
-                msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}"
-                raise ValueError(msg)
-        return v
+    def get_run_name(self) -> str:
+        if self.run_name is None:
+            self.run_name = get_run_name()
+        return self.run_name
 
 
 class VectorIndexConfig(BaseModel):
     """Configuration for the vector index."""
 
-    save_db: bool = False
+    save_db: bool = Field(False, description="Whether to save the vector index database or not")
     """Whether to save the vector index database or not"""

autointent/nodes/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,10 +2,10 @@
 
 from ._inference_node import InferenceNode
 from ._optimization import NodeOptimizer
-from .schemes import OptimizationConfig
+from .schemes import OptimizationSearchSpaceConfig
 
 __all__ = [
     "InferenceNode",
     "NodeOptimizer",
-    "OptimizationConfig",
+    "OptimizationSearchSpaceConfig",
 ]

autointent/nodes/schemes.py

Lines changed: 21 additions & 2 deletions
@@ -6,7 +6,8 @@
 
 from pydantic import BaseModel, Field, PositiveInt, RootModel
 
-from autointent.custom_types import NodeType
+from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
+from autointent.custom_types import NodeType, SamplerType
 from autointent.modules.abc import BaseModule
 from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt
 from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
@@ -160,7 +161,7 @@ class RegexNodeValidator(BaseModel):
 SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator
 
 
-class OptimizationConfig(RootModel[list[SearchSpaceTypes]]):
+class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]):
     """Optimizer configuration."""
 
     def __iter__(
@@ -178,3 +179,21 @@ def __getitem__(self, item: int) -> SearchSpaceTypes:
         :return: Item
         """
         return self.root[item]
+
+
+class TaskConfig(BaseModel):
+    """Configuration for the task to optimize."""
+
+    search_space: OptimizationSearchSpaceConfig
+    """Path to the search space configuration file. If None, the default search space will be used"""
+    sampler: SamplerType = "brute"
+
+
+class OptimizationConfig(BaseModel):
+    """Configuration for the optimization process."""
+
+    data_config: DataConfig = DataConfig()
+    task_config: TaskConfig
+    logging_config: LoggingConfig = LoggingConfig()
+    vector_index_config: VectorIndexConfig = VectorIndexConfig()
+    seed: PositiveInt = 42
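
To show how the new models compose, a hedged sketch of building an OptimizationConfig programmatically; the field values are illustrative, and reusing load_default_search_space avoids inventing a node schema not shown in this diff:

from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig, TaskConfig
from autointent.utils import load_default_search_space

task = TaskConfig(
    search_space=OptimizationSearchSpaceConfig(load_default_search_space(False)),
    sampler="brute",
)
config = OptimizationConfig(
    task_config=task,
    data_config=DataConfig(n_folds=5),
    logging_config=LoggingConfig(dump_modules=True),
    vector_index_config=VectorIndexConfig(save_db=False),
    seed=42,
)
# model_dump() yields the same nested dict/YAML structure that
# Pipeline.from_optimization_config accepts.
print(config.model_dump()["task_config"]["sampler"])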
