Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autointent/_callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}

REPORTERS_NAMES = list(REPORTERS.keys())
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type]


def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
Expand Down
46 changes: 40 additions & 6 deletions autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, get_args

import numpy as np
import yaml
Expand All @@ -13,7 +13,7 @@
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
from autointent.metrics import DECISION_METRICS
from autointent.nodes import InferenceNode, NodeOptimizer
from autointent.nodes.schemes import OptimizationConfig
from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig
from autointent.utils import load_default_search_space, load_search_space

from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
Expand All @@ -28,17 +28,24 @@ class Pipeline:
def __init__(
self,
nodes: list[NodeOptimizer] | list[InferenceNode],
sampler: SamplerType = "brute",
seed: int = 42,
) -> None:
"""
Initialize the pipeline optimizer.

:param nodes: list of nodes
:param sampler: sampler type
:param seed: random seed
"""
self._logger = logging.getLogger(__name__)
self.nodes = {node.node_type: node for node in nodes}
self.seed = seed
if sampler not in get_args(SamplerType):
msg = f"Sampler should be one of {get_args(SamplerType)}"
raise ValueError(msg)

self.sampler = sampler

if isinstance(nodes[0], NodeOptimizer):
self.logging_config = LoggingConfig(dump_dir=None)
Expand Down Expand Up @@ -74,10 +81,34 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
"""
if isinstance(search_space, Path | str):
search_space = load_search_space(search_space)
validated_search_space = OptimizationConfig(search_space).model_dump() # type: ignore[arg-type]
validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump() # type: ignore[arg-type]
nodes = [NodeOptimizer(**node) for node in validated_search_space]
return cls(nodes=nodes, seed=seed)

@classmethod
def from_optimization_config(cls, config: dict[str, Any] | Path | str) -> "Pipeline":
"""
Create pipeline optimizer from optimization config.

:param config: Optimization config
:return:
"""
if isinstance(config, Path | str):
with Path(config).open() as file:
loaded_config = yaml.safe_load(file)
else:
loaded_config = config
optimization_config = OptimizationConfig(**loaded_config)
pipeline = cls(
[NodeOptimizer(**node.model_dump()) for node in optimization_config.task_config.search_space],
optimization_config.task_config.sampler,
optimization_config.seed,
)
pipeline.set_config(optimization_config.logging_config)
pipeline.set_config(optimization_config.vector_index_config)
pipeline.set_config(optimization_config.data_config)
return pipeline

@classmethod
def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
"""
Expand All @@ -90,7 +121,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
"""
return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)

def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
def _fit(self, context: Context, sampler: SamplerType) -> None:
"""
Optimize the pipeline.

Expand All @@ -99,7 +130,7 @@ def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
self.context = context
self._logger.info("starting pipeline optimization...")
self.context.callback_handler.start_run(
run_name=self.context.logging_config.run_name,
run_name=self.context.logging_config.get_run_name(),
dirpath=self.context.logging_config.dirpath,
)
for node_type in NodeType:
Expand All @@ -123,7 +154,7 @@ def fit(
self,
dataset: Dataset,
refit_after: bool = False,
sampler: SamplerType = "brute",
sampler: SamplerType | None = None,
) -> Context:
"""
Optimize the pipeline from dataset.
Expand All @@ -148,6 +179,9 @@ def fit(
"Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
)

if sampler is None:
sampler = self.sampler or "brute"

self._fit(context, sampler)

if context.is_ram_to_clear():
Expand Down
2 changes: 0 additions & 2 deletions autointent/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from ._optimization import (
DataConfig,
LoggingConfig,
TaskConfig,
VectorIndexConfig,
)

Expand All @@ -13,6 +12,5 @@
"InferenceNodeConfig",
"InferenceNodeConfig",
"LoggingConfig",
"TaskConfig",
"VectorIndexConfig",
]
66 changes: 32 additions & 34 deletions autointent/configs/_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,78 +2,76 @@

from pathlib import Path

from pydantic import BaseModel, Field, PositiveInt, field_validator
from pydantic import BaseModel, Field, PositiveInt

from autointent._callbacks import REPORTERS_NAMES
from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
from autointent.custom_types import FloatFromZeroToOne, ValidationScheme

from ._name import get_run_name


class DataConfig(BaseModel):
"""Configuration for the data used in the optimization process."""

scheme: ValidationScheme = "ho"
scheme: ValidationScheme = Field("ho", description="Validation scheme to use.")
"""Hold-out or cross-validation."""
n_folds: PositiveInt = 3
n_folds: PositiveInt = Field(3, description="Number of folds in cross-validation.")
"""Number of folds in cross-validation."""
validation_size: FloatFromZeroToOne = 0.2
validation_size: FloatFromZeroToOne = Field(
0.2,
description=(
"Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."
),
)
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
separation_ratio: FloatFromZeroToOne | None = 0.5
separation_ratio: FloatFromZeroToOne | None = Field(
0.5, description="Set to float to prevent data leak between scoring and decision nodes."
)
"""Set to float to prevent data leak between scoring and decision nodes."""


class TaskConfig(BaseModel):
"""Configuration for the task to optimize."""

search_space_path: Path | None = None
"""Path to the search space configuration file. If None, the default search space will be used"""
sampler: SamplerType = "brute"


class LoggingConfig(BaseModel):
"""Configuration for the logging."""

project_dir: Path = Field(default_factory=lambda: Path.cwd() / "runs")
_dirpath: Path | None = None
_dump_dir: Path | None = None

project_dir: Path | str | None = Field(None, description="Path to the directory with different runs.")
"""Path to the directory with different runs."""
run_name: str = Field(default_factory=get_run_name)
run_name: str | None = Field(None, description="Name of the run. If None, a random name will be generated.")
"""Name of the run. If None, a random name will be generated"""
dump_modules: bool = False
dump_modules: bool = Field(False, description="Whether to dump the modules or not")
"""Whether to dump the modules or not"""
clear_ram: bool = False
clear_ram: bool = Field(False, description="Whether to clear the RAM after dumping the modules")
"""Whether to clear the RAM after dumping the modules"""
report_to: list[str] | None = None
report_to: list[REPORTERS_NAMES] | None = Field( # type: ignore[valid-type]
None, description="List of callbacks to report to. If None, no callbacks will be used"
)
"""List of callbacks to report to. If None, no callbacks will be used"""

@property
def dirpath(self) -> Path:
"""Path to the directory where the logs will be saved."""
if not hasattr(self, "_dirpath"):
self._dirpath = self.project_dir / self.run_name
if self._dirpath is None:
project_dir = Path.cwd() / "runs" if self.project_dir is None else Path(self.project_dir)
self._dirpath = project_dir / self.get_run_name()
return self._dirpath

@property
def dump_dir(self) -> Path:
"""Path to the directory where the modules will be dumped."""
if not hasattr(self, "_dump_dir"):
if self._dump_dir is None:
self._dump_dir = self.dirpath / "modules_dumps"
return self._dump_dir

@field_validator("report_to")
@classmethod
def validate_report_to(cls, v: list[str] | None) -> list[str] | None:
"""Validate the report_to field."""
if v is None:
return None
for reporter in v:
if reporter not in REPORTERS_NAMES:
msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}"
raise ValueError(msg)
return v
def get_run_name(self) -> str:
if self.run_name is None:
self.run_name = get_run_name()
return self.run_name


class VectorIndexConfig(BaseModel):
"""Configuration for the vector index."""

save_db: bool = False
save_db: bool = Field(False, description="Whether to save the vector index database or not")
"""Whether to save the vector index database or not"""
4 changes: 2 additions & 2 deletions autointent/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ._inference_node import InferenceNode
from ._optimization import NodeOptimizer
from .schemes import OptimizationConfig
from .schemes import OptimizationSearchSpaceConfig

__all__ = [
"InferenceNode",
"NodeOptimizer",
"OptimizationConfig",
"OptimizationSearchSpaceConfig",
]
23 changes: 21 additions & 2 deletions autointent/nodes/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from pydantic import BaseModel, Field, PositiveInt, RootModel

from autointent.custom_types import NodeType
from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import NodeType, SamplerType
from autointent.modules.abc import BaseModule
from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt
from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
Expand Down Expand Up @@ -160,7 +161,7 @@ class RegexNodeValidator(BaseModel):
SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator


class OptimizationConfig(RootModel[list[SearchSpaceTypes]]):
class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]):
"""Optimizer configuration."""

def __iter__(
Expand All @@ -178,3 +179,21 @@ def __getitem__(self, item: int) -> SearchSpaceTypes:
:return: Item
"""
return self.root[item]


class TaskConfig(BaseModel):
"""Configuration for the task to optimize."""

search_space: OptimizationSearchSpaceConfig
"""Path to the search space configuration file. If None, the default search space will be used"""
sampler: SamplerType = "brute"


class OptimizationConfig(BaseModel):
"""Configuration for the optimization process."""

data_config: DataConfig = DataConfig()
task_config: TaskConfig
logging_config: LoggingConfig = LoggingConfig()
vector_index_config: VectorIndexConfig = VectorIndexConfig()
seed: PositiveInt = 42
Comment on lines +184 to +199
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

это конфиги и они должны быть в autointent.configs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Могу сюда все конфиги тогда перенести

Loading
Loading