2 changes: 1 addition & 1 deletion autointent/_callbacks/__init__.py
@@ -7,7 +7,7 @@

 REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}

-REPORTERS_NAMES = list(REPORTERS.keys())
+REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())]  # type: ignore[valid-type]


 def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
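Switching from a plain `list[str]` to a dynamically built `Literal` lets pydantic reject unknown reporter names during model validation. A minimal self-contained sketch of the pattern (the registry below uses stand-in values, not the real callback classes):

```python
from typing import Literal

from pydantic import BaseModel, ValidationError

# Stand-in registry; in autointent this is built from the callback classes.
REPORTERS = {"wandb": object(), "tensorboard": object()}
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())]  # type: ignore[valid-type]


class LoggingSettings(BaseModel):  # hypothetical model, for illustration only
    report_to: list[REPORTERS_NAMES] | None = None  # type: ignore[valid-type]


LoggingSettings(report_to=["wandb"])  # accepted

try:
    LoggingSettings(report_to=["mlflow"])  # rejected: not a member of the Literal
except ValidationError as err:
    print(err)
```

Static type checkers cannot evaluate the `tuple(...)` call, hence the `type: ignore[valid-type]` comments; the payoff is runtime validation that stays in sync with the registry automatically.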
44 changes: 39 additions & 5 deletions autointent/_pipeline/_pipeline.py
@@ -3,7 +3,7 @@
 import json
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, get_args

 import numpy as np
 import yaml
@@ -13,7 +13,7 @@
 from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
 from autointent.metrics import DECISION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
-from autointent.nodes.schemes import OptimizationConfig
+from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig
 from autointent.utils import load_default_search_space, load_search_space

 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
@@ -28,17 +28,24 @@ class Pipeline:
     def __init__(
         self,
         nodes: list[NodeOptimizer] | list[InferenceNode],
+        sampler: SamplerType = "brute",
         seed: int = 42,
     ) -> None:
         """
         Initialize the pipeline optimizer.

         :param nodes: list of nodes
+        :param sampler: sampler type
         :param seed: random seed
         """
         self._logger = logging.getLogger(__name__)
         self.nodes = {node.node_type: node for node in nodes}
         self.seed = seed
+        if sampler not in get_args(SamplerType):
+            msg = f"Sampler should be one of {get_args(SamplerType)}"
+            raise ValueError(msg)
+
+        self.sampler = sampler

         if isinstance(nodes[0], NodeOptimizer):
             self.logging_config = LoggingConfig(dump_dir=None)
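The constructor now fails fast on an unsupported sampler instead of deferring the error to `fit`. A sketch of the guard in isolation (the concrete members of `SamplerType` are assumed here; the real alias lives in `autointent.custom_types`):

```python
from typing import Literal, get_args

SamplerType = Literal["brute", "tpe", "random"]  # assumed members, for illustration


def validate_sampler(sampler: str) -> None:
    # get_args() extracts the Literal's values, so the check and the error
    # message both track the type alias automatically.
    if sampler not in get_args(SamplerType):
        msg = f"Sampler should be one of {get_args(SamplerType)}"
        raise ValueError(msg)


validate_sampler("brute")  # ok
validate_sampler("grid")   # ValueError: Sampler should be one of ('brute', 'tpe', 'random')
```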
@@ -74,10 +81,34 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
         """
         if isinstance(search_space, Path | str):
             search_space = load_search_space(search_space)
-        validated_search_space = OptimizationConfig(search_space).model_dump()  # type: ignore[arg-type]
+        validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump()  # type: ignore[arg-type]
         nodes = [NodeOptimizer(**node) for node in validated_search_space]
         return cls(nodes=nodes, seed=seed)

+    @classmethod
+    def from_optimization_config(cls, config: dict[str, Any] | Path | str) -> "Pipeline":
+        """
+        Create pipeline optimizer from optimization config.
+
+        :param config: Optimization config
+        :return: Pipeline optimizer
+        """
+        if isinstance(config, Path | str):
+            with Path(config).open() as file:
+                loaded_config = yaml.safe_load(file)
+        else:
+            loaded_config = config
+        optimization_config = OptimizationConfig(**loaded_config)
+        pipeline = cls(
+            [NodeOptimizer(**node.model_dump()) for node in optimization_config.task_config.search_space],
+            optimization_config.task_config.sampler,
+            optimization_config.seed,
+        )
+        pipeline.set_config(optimization_config.logging_config)
+        pipeline.set_config(optimization_config.vector_index_config)
+        pipeline.set_config(optimization_config.data_config)
+        return pipeline
+
     @classmethod
     def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         """
@@ -90,7 +121,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         """
         return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)

-    def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
+    def _fit(self, context: Context, sampler: SamplerType) -> None:
         """
         Optimize the pipeline.

@@ -123,7 +154,7 @@ def fit(
         self,
         dataset: Dataset,
         refit_after: bool = False,
-        sampler: SamplerType = "brute",
+        sampler: SamplerType | None = None,
     ) -> Context:
         """
         Optimize the pipeline from dataset.
@@ -148,6 +179,9 @@
                 "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
             )

+        if sampler is None:
+            sampler = self.sampler or "brute"
+
        self._fit(context, sampler)

         if context.is_ram_to_clear():
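Taken together, these changes mean `fit()` resolves the sampler in order: explicit argument, then the value stored at construction, then `"brute"`. A hypothetical end-to-end call of the new classmethod (`from_optimization_config` also accepts a path to a YAML file; the node entry below is illustrative and not taken from the PR):

```python
from autointent import Pipeline  # import path is an assumption

config = {
    "seed": 42,
    "task_config": {
        "sampler": "brute",
        # Entries must satisfy the OptimizationSearchSpaceConfig schema;
        # the keys below are placeholders, not taken from the PR.
        "search_space": [
            {"node_type": "scoring", "search_space": []},
        ],
    },
    # data_config, logging_config, and vector_index_config fall back to defaults.
}

pipeline = Pipeline.from_optimization_config(config)
```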
2 changes: 0 additions & 2 deletions autointent/configs/__init__.py
@@ -4,7 +4,6 @@
 from ._optimization import (
     DataConfig,
     LoggingConfig,
-    TaskConfig,
     VectorIndexConfig,
 )

@@ -13,6 +12,5 @@
     "InferenceNodeConfig",
     "InferenceNodeConfig",
     "LoggingConfig",
-    "TaskConfig",
     "VectorIndexConfig",
 ]
30 changes: 5 additions & 25 deletions autointent/configs/_optimization.py
@@ -2,10 +2,10 @@

 from pathlib import Path

-from pydantic import BaseModel, Field, PositiveInt, field_validator
+from pydantic import BaseModel, Field, PositiveInt

 from autointent._callbacks import REPORTERS_NAMES
-from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
+from autointent.custom_types import FloatFromZeroToOne, ValidationScheme

 from ._name import get_run_name

@@ -23,33 +23,25 @@ class DataConfig(BaseModel):
     """Set to float to prevent data leak between scoring and decision nodes."""


-class TaskConfig(BaseModel):
-    """Configuration for the task to optimize."""
-
-    search_space_path: Path | None = None
-    """Path to the search space configuration file. If None, the default search space will be used"""
-    sampler: SamplerType = "brute"
-
-
 class LoggingConfig(BaseModel):
     """Configuration for the logging."""

-    project_dir: Path = Field(default_factory=lambda: Path.cwd() / "runs")
+    project_dir: Path | str = Field(default_factory=lambda: Path.cwd() / "runs")
     """Path to the directory with different runs."""
     run_name: str = Field(default_factory=get_run_name)
     """Name of the run. If None, a random name will be generated"""
     dump_modules: bool = False
     """Whether to dump the modules or not"""
     clear_ram: bool = False
     """Whether to clear the RAM after dumping the modules"""
-    report_to: list[str] | None = None
+    report_to: list[REPORTERS_NAMES] | None = None  # type: ignore[valid-type]
     """List of callbacks to report to. If None, no callbacks will be used"""

     @property
     def dirpath(self) -> Path:
         """Path to the directory where the logs will be saved."""
         if not hasattr(self, "_dirpath"):
-            self._dirpath = self.project_dir / self.run_name
+            self._dirpath = Path(self.project_dir) / self.run_name
         return self._dirpath

     @property
@@ -59,18 +51,6 @@ def dump_dir(self) -> Path:
             self._dump_dir = self.dirpath / "modules_dumps"
         return self._dump_dir

-    @field_validator("report_to")
-    @classmethod
-    def validate_report_to(cls, v: list[str] | None) -> list[str] | None:
-        """Validate the report_to field."""
-        if v is None:
-            return None
-        for reporter in v:
-            if reporter not in REPORTERS_NAMES:
-                msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}"
-                raise ValueError(msg)
-        return v
-

 class VectorIndexConfig(BaseModel):
     """Configuration for the vector index."""
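Two user-visible effects of this file's changes, sketched below (assuming `LoggingConfig` is re-exported from `autointent.configs` as in the `__init__.py` diff above, and that the reporter names match the callbacks' `name` attributes):

```python
from autointent.configs import LoggingConfig

# project_dir may now be a plain string; the dirpath property coerces it with Path().
cfg = LoggingConfig(project_dir="runs/experiment-1", report_to=["tensorboard"])
print(cfg.dirpath)  # e.g. runs/experiment-1/<generated run name>

# Unknown reporters now fail inside pydantic's Literal check at construction
# time; the hand-written field_validator is gone.
# LoggingConfig(report_to=["mlflow"])  # raises pydantic.ValidationError
```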
4 changes: 2 additions & 2 deletions autointent/nodes/__init__.py
@@ -2,10 +2,10 @@

 from ._inference_node import InferenceNode
 from ._optimization import NodeOptimizer
-from .schemes import OptimizationConfig
+from .schemes import OptimizationSearchSpaceConfig

 __all__ = [
     "InferenceNode",
     "NodeOptimizer",
-    "OptimizationConfig",
+    "OptimizationSearchSpaceConfig",
 ]
23 changes: 21 additions & 2 deletions autointent/nodes/schemes.py
@@ -6,7 +6,8 @@

 from pydantic import BaseModel, Field, PositiveInt, RootModel

-from autointent.custom_types import NodeType
+from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
+from autointent.custom_types import NodeType, SamplerType
 from autointent.modules.abc import BaseModule
 from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt
 from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
@@ -160,7 +161,7 @@ class RegexNodeValidator(BaseModel):
 SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator


-class OptimizationConfig(RootModel[list[SearchSpaceTypes]]):
+class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]):
     """Optimizer configuration."""

     def __iter__(
@@ -178,3 +179,21 @@ def __getitem__(self, item: int) -> SearchSpaceTypes:
         :return: Item
         """
         return self.root[item]
+
+
+class TaskConfig(BaseModel):
+    """Configuration for the task to optimize."""
+
+    search_space: OptimizationSearchSpaceConfig
+    """Search space configuration for the pipeline nodes."""
+    sampler: SamplerType = "brute"
+
+
+class OptimizationConfig(BaseModel):
+    """Configuration for the optimization process."""
+
+    data_config: DataConfig = DataConfig()
+    task_config: TaskConfig
+    logging_config: LoggingConfig = LoggingConfig()
+    vector_index_config: VectorIndexConfig = VectorIndexConfig()
+    seed: PositiveInt = 42
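The new models can also be built programmatically; a minimal sketch (the empty search space is only for illustration):

```python
from autointent.nodes.schemes import (
    OptimizationConfig,
    OptimizationSearchSpaceConfig,
    TaskConfig,
)

search_space = OptimizationSearchSpaceConfig([])  # normally a list of node definitions
task = TaskConfig(search_space=search_space)      # sampler defaults to "brute"
config = OptimizationConfig(task_config=task)     # other sections use their defaults

print(config.seed)  # 42
```

The instance defaults such as `DataConfig()` are safe here: pydantic copies field defaults for each model instantiation, so separate `OptimizationConfig` objects do not share mutable state the way a plain mutable default argument would.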
Comment on lines +184 to +199

Collaborator: These are configs, and they belong in autointent.configs.

Member (Author): I could move all the configs here, then.