Skip to content
87 changes: 64 additions & 23 deletions autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput

if TYPE_CHECKING:
from autointent.modules.base import BaseDecision, BaseScorer
from autointent.modules.base import BaseDecision, BaseRegex, BaseScorer


class Pipeline:
Expand All @@ -41,7 +41,7 @@ def __init__(
self,
nodes: list[NodeOptimizer] | list[InferenceNode],
sampler: SamplerType = "brute",
seed: int = 42,
seed: int | None = 42,
) -> None:
"""Initialize the pipeline optimizer.

Expand Down Expand Up @@ -85,7 +85,7 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
assert_never(config)

@classmethod
def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int | None = 42) -> "Pipeline":
"""Search space to pipeline optimizer.

Args:
Expand All @@ -101,7 +101,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
return cls(nodes=nodes, seed=seed)

@classmethod
def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline":
def from_preset(cls, name: SearchSpacePresets, seed: int | None = 42) -> "Pipeline":
optimization_config = load_preset(name)
config = OptimizationConfig(seed=seed, **optimization_config)
return cls.from_optimization_config(config=config)
Expand Down Expand Up @@ -186,7 +186,7 @@ def fit(
msg = "Pipeline in inference mode cannot be fitted"
raise RuntimeError(msg)

context = Context()
context = Context(self.seed)
context.set_dataset(dataset, self.data_config)
context.configure_logging(self.logging_config)
context.configure_transformer(self.embedder_config)
Expand All @@ -199,25 +199,43 @@ def fit(
self._logger.warning(
"Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
)
elif context.logging_config.clear_ram and not context.logging_config.dump_modules:
self._logger.warning(
"Test data is provided, but final metrics won't be calculated "
"because fitted modules won't be saved neither in RAM nor in file system."
"Change settings in LoggerConfig to obtain different behavior."
)

if sampler is None:
sampler = self.sampler

self._fit(context, sampler)

if context.is_ram_to_clear():
if context.logging_config.clear_ram and context.logging_config.dump_modules:
nodes_configs = context.optimization_info.get_inference_nodes_config()
nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
else:
elif not context.logging_config.clear_ram:
modules_dict = context.optimization_info.get_best_modules()
nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
else:
self._logger.info(
"Skipping calculating final metrics because fitted modules weren't saved."
"Change settings in LoggerConfig to obtain different behavior."
)
return context

self.nodes = {node.node_type: node for node in nodes_list}
self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding}

if refit_after:
# TODO reflect this refitting in dumped version of pipeline
self._refit(context)

self._nodes_configs: dict[str, InferenceNodeConfig] = {
NodeType(cfg.node_type): cfg
for cfg in context.optimization_info.get_inference_nodes_config()
if cfg.node_type != NodeType.embedding
}
self._dump_dir = context.logging_config.dirpath

if test_utterances is not None:
predictions = self.predict(test_utterances)
for metric_name, metric in DECISION_METRICS.items():
Expand All @@ -229,6 +247,41 @@ def fit(

return context

def dump(self, path: str | Path | None = None) -> None:
if isinstance(path, str):
path = Path(path)
elif path is None:
if hasattr(self, "_dump_dir"):
path = self._dump_dir
else:
msg = (
"Either you didn't trained the pipeline yet or fitted modules weren't saved during optimization. "
"Change settings in LoggerConfig and retrain the pipeline to obtain different behavior."
)
self._logger.error(msg)
raise RuntimeError(msg)

scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]

scoring_dump_dir = str(path / "scoring_module")
decision_dump_dir = str(path / "decision_module")
scoring_module.dump(scoring_dump_dir)
decision_module.dump(decision_dump_dir)

self._nodes_configs[NodeType.scoring].load_path = scoring_dump_dir
self._nodes_configs[NodeType.decision].load_path = decision_dump_dir

if NodeType.regex in self.nodes:
regex_module: BaseRegex = self.nodes[NodeType.regex].module # type: ignore[assignment,union-attr]
regex_dump_dir = str(path / "regex_module")
regex_module.dump(regex_dump_dir)
self._nodes_configs[NodeType.regex].load_path = regex_dump_dir

inference_nodes_configs = [cfg.asdict() for cfg in self._nodes_configs.values()]
with (path / "inference_config.yaml").open("w") as file:
yaml.dump(inference_nodes_configs, file)

def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None:
"""Validate modules with dataset.

Expand All @@ -240,18 +293,6 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) ->
if isinstance(node, NodeOptimizer):
node.validate_nodes_with_dataset(dataset, mode)

@classmethod
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
    """Build an inference pipeline from plain-dictionary node configs.

    Args:
        nodes_configs: one dictionary per node; each is unpacked as keyword
            arguments of ``InferenceNodeConfig``

    Returns:
        Inference pipeline produced by ``from_config``
    """
    parsed_configs = []
    for raw_config in nodes_configs:
        parsed_configs.append(InferenceNodeConfig(**raw_config))
    return cls.from_config(parsed_configs)

@classmethod
def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
"""Create inference pipeline from config.
Expand Down Expand Up @@ -283,13 +324,13 @@ def load(
Inference pipeline
"""
with (Path(path) / "inference_config.yaml").open() as file:
inference_dict_config: dict[str, Any] = yaml.safe_load(file)
inference_nodes_configs: list[dict[str, Any]] = yaml.safe_load(file)

inference_config = [
InferenceNodeConfig(
**node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config
)
for node_config in inference_dict_config["nodes_configs"]
for node_config in inference_nodes_configs
]

return cls.from_config(inference_config)
Expand Down
14 changes: 13 additions & 1 deletion autointent/configs/_inference_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Configuration for the nodes."""

from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any

from autointent.custom_types import NodeType
Expand All @@ -24,3 +24,15 @@ class InferenceNodeConfig:
"""One can override presaved embedder config while loading from file system."""
cross_encoder_config: CrossEncoderConfig | None = None
"""One can override presaved cross encoder config while loading from file system."""

def asdict(self) -> dict[str, Any]:
    """Convert the node config to a dictionary.

    The ``embedder_config`` and ``cross_encoder_config`` entries are replaced
    by the result of their ``model_dump()`` when set, and removed from the
    result when they are ``None``.

    Returns:
        Dictionary form of this config.
    """
    # Inside the method body the bare name resolves to the module-level
    # ``dataclasses.asdict``, not to this method.
    payload = asdict(self)
    for field_name in ("embedder_config", "cross_encoder_config"):
        override = getattr(self, field_name)
        if override is None:
            payload.pop(field_name)
        else:
            payload[field_name] = override.model_dump()
    return payload
21 changes: 2 additions & 19 deletions autointent/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
import logging
from pathlib import Path
from typing import Any

import yaml

Expand Down Expand Up @@ -32,7 +31,7 @@ class Context:
optimization_info: OptimizationInfo
callback_handler = CallbackHandler()

def __init__(self, seed: int = 42) -> None:
def __init__(self, seed: int | None = 42) -> None:
"""Initialize the Context object.

Args:
Expand Down Expand Up @@ -71,22 +70,6 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
"""
self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config)

def get_inference_config(self) -> dict[str, Any]:
    """Assemble the configuration settings used for inference.

    Returns:
        Dictionary with a ``metadata`` section (``multilabel``, ``n_classes``
        and ``seed``) and a ``nodes_configs`` section obtained from the
        optimization info.
    """
    serialized_nodes = self.optimization_info.get_inference_nodes_config(asdict=True)
    metadata = {
        "multilabel": self.is_multilabel(),
        "n_classes": self.get_n_classes(),
        "seed": self.seed,
    }
    return {"metadata": metadata, "nodes_configs": serialized_nodes}

def dump(self) -> None:
"""Save logs, configurations, and datasets to disk."""
self._logger.debug("dumping logs...")
Expand All @@ -103,7 +86,7 @@ def dump(self) -> None:

self._logger.info("logs and other assets are saved to %s", logs_dir)

inference_config = self.get_inference_config()
inference_config = self.optimization_info.get_inference_nodes_config(asdict=True)
inference_config_path = logs_dir / "inference_config.yaml"
with inference_config_path.open("w") as file:
yaml.dump(inference_config, file)
Expand Down
32 changes: 5 additions & 27 deletions autointent/context/data_handler/_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections.abc import Generator
from typing import TypedDict, cast
from typing import cast

from datasets import concatenate_datasets
from transformers import set_seed
Expand All @@ -16,28 +16,14 @@
logger = logging.getLogger(__name__)


class RegexPatterns(TypedDict):
    """Regex patterns for each intent class.

    A typed dictionary with one integer id and two lists of pattern strings.

    Attributes:
        id: Intent class id.
        regex_full_match: Full match regex patterns.
        regex_partial_match: Partial match regex patterns.
    """

    id: int
    regex_full_match: list[str]
    regex_partial_match: list[str]


class DataHandler: # TODO rename to Validator
class DataHandler:
"""Data handler class."""

def __init__(
self,
dataset: Dataset,
config: DataConfig | None = None,
random_seed: int = 0,
random_seed: int | None = 0,
) -> None:
"""Initialize the data handler.

Expand All @@ -46,7 +32,8 @@ def __init__(
config: Configuration object
random_seed: Seed for random number generation.
"""
set_seed(random_seed)
if random_seed is not None:
set_seed(random_seed)
self.random_seed = random_seed

self.dataset = dataset
Expand All @@ -59,15 +46,6 @@ def __init__(
elif self.config.scheme == "cv":
self._split_cv()

self.regex_patterns = [
RegexPatterns(
id=intent.id,
regex_full_match=intent.regex_full_match,
regex_partial_match=intent.regex_partial_match,
)
for intent in self.dataset.intents
]

self.intent_descriptions = [intent.description for intent in self.dataset.intents]
self.tags = self.dataset.get_tags()

Expand Down
4 changes: 2 additions & 2 deletions autointent/context/data_handler/_stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self,
test_size: float,
label_feature: str,
random_seed: int,
random_seed: int | None,
shuffle: bool = True,
) -> None:
"""Initialize the StratifiedSplitter.
Expand Down Expand Up @@ -283,7 +283,7 @@ def split_dataset(
dataset: Dataset,
split: str,
test_size: float,
random_seed: int,
random_seed: int | None,
allow_oos_in_train: bool | None = None,
) -> tuple[HFDataset, HFDataset]:
"""Split a Dataset object into training and testing subsets.
Expand Down
2 changes: 1 addition & 1 deletion autointent/metrics/regex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Regex metrics for intent recognition."""
"""Metrics for regex modules."""

from typing import Protocol

Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TunableDecision,
)
from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
from .regex import Regex
from .regex import SimpleRegex
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer

T = TypeVar("T", bound=BaseModule)
Expand All @@ -21,7 +21,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
return {module.name: module for module in modules}


REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex])
REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([SimpleRegex])

EMBEDDING_MODULES: dict[str, type[BaseEmbedding]] = _create_modules_dict(
[RetrievalAimedEmbedding, LogregAimedEmbedding]
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/decision/_tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self,
target_metric: MetricType = "decision_accuracy",
n_optuna_trials: PositiveInt = 320,
seed: int = 0,
seed: int | None = 0,
tags: list[Tag] | None = None,
) -> None:
"""Initialize tunable predictor.
Expand Down Expand Up @@ -222,7 +222,7 @@ def fit(
self,
probas: npt.NDArray[Any],
labels: ListOfGenericLabels,
seed: int,
seed: int | None,
tags: list[Tag] | None = None,
) -> None:
"""Fit the optimizer by finding optimal thresholds.
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/regex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._simple import Regex
from ._simple import SimpleRegex

__all__ = ["Regex"]
__all__ = ["SimpleRegex"]
Loading
Loading