diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 0549ecb05..0e9712395 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -31,7 +31,7 @@ from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput if TYPE_CHECKING: - from autointent.modules.base import BaseDecision, BaseScorer + from autointent.modules.base import BaseDecision, BaseRegex, BaseScorer class Pipeline: @@ -41,7 +41,7 @@ def __init__( self, nodes: list[NodeOptimizer] | list[InferenceNode], sampler: SamplerType = "brute", - seed: int = 42, + seed: int | None = 42, ) -> None: """Initialize the pipeline optimizer. @@ -85,7 +85,7 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig assert_never(config) @classmethod - def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline": + def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int | None = 42) -> "Pipeline": """Search space to pipeline optimizer. Args: @@ -101,7 +101,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed return cls(nodes=nodes, seed=seed) @classmethod - def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline": + def from_preset(cls, name: SearchSpacePresets, seed: int | None = 42) -> "Pipeline": optimization_config = load_preset(name) config = OptimizationConfig(seed=seed, **optimization_config) return cls.from_optimization_config(config=config) @@ -186,7 +186,7 @@ def fit( msg = "Pipeline in inference mode cannot be fitted" raise RuntimeError(msg) - context = Context() + context = Context(self.seed) context.set_dataset(dataset, self.data_config) context.configure_logging(self.logging_config) context.configure_transformer(self.embedder_config) @@ -199,25 +199,43 @@ def fit( self._logger.warning( "Test data is not provided. Final test metrics won't be calculated after pipeline optimization." ) + elif context.logging_config.clear_ram and not context.logging_config.dump_modules: + self._logger.warning( + "Test data is provided, but final metrics won't be calculated " + "because fitted modules won't be saved neither in RAM nor in file system." + "Change settings in LoggerConfig to obtain different behavior." + ) if sampler is None: sampler = self.sampler self._fit(context, sampler) - if context.is_ram_to_clear(): + if context.logging_config.clear_ram and context.logging_config.dump_modules: nodes_configs = context.optimization_info.get_inference_nodes_config() nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs] - else: + elif not context.logging_config.clear_ram: modules_dict = context.optimization_info.get_best_modules() nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()] + else: + self._logger.info( + "Skipping calculating final metrics because fitted modules weren't saved." + "Change settings in LoggerConfig to obtain different behavior." + ) + return context - self.nodes = {node.node_type: node for node in nodes_list} + self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding} if refit_after: - # TODO reflect this refitting in dumped version of pipeline self._refit(context) + self._nodes_configs: dict[str, InferenceNodeConfig] = { + NodeType(cfg.node_type): cfg + for cfg in context.optimization_info.get_inference_nodes_config() + if cfg.node_type != NodeType.embedding + } + self._dump_dir = context.logging_config.dirpath + if test_utterances is not None: predictions = self.predict(test_utterances) for metric_name, metric in DECISION_METRICS.items(): @@ -229,6 +247,41 @@ def fit( return context + def dump(self, path: str | Path | None = None) -> None: + if isinstance(path, str): + path = Path(path) + elif path is None: + if hasattr(self, "_dump_dir"): + path = self._dump_dir + else: + msg = ( + "Either you didn't trained the pipeline yet or fitted modules weren't saved during optimization. " + "Change settings in LoggerConfig and retrain the pipeline to obtain different behavior." + ) + self._logger.error(msg) + raise RuntimeError(msg) + + scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr] + decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr] + + scoring_dump_dir = str(path / "scoring_module") + decision_dump_dir = str(path / "decision_module") + scoring_module.dump(scoring_dump_dir) + decision_module.dump(decision_dump_dir) + + self._nodes_configs[NodeType.scoring].load_path = scoring_dump_dir + self._nodes_configs[NodeType.decision].load_path = decision_dump_dir + + if NodeType.regex in self.nodes: + regex_module: BaseRegex = self.nodes[NodeType.regex].module # type: ignore[assignment,union-attr] + regex_dump_dir = str(path / "regex_module") + regex_module.dump(regex_dump_dir) + self._nodes_configs[NodeType.regex].load_path = regex_dump_dir + + inference_nodes_configs = [cfg.asdict() for cfg in self._nodes_configs.values()] + with (path / "inference_config.yaml").open("w") as file: + yaml.dump(inference_nodes_configs, file) + def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None: """Validate modules with dataset. @@ -240,18 +293,6 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> if isinstance(node, NodeOptimizer): node.validate_nodes_with_dataset(dataset, mode) - @classmethod - def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline": - """Create inference pipeline from dictionary config. - - Args: - nodes_configs: list of config for nodes - - Returns: - Inference pipeline - """ - return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs]) - @classmethod def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline": """Create inference pipeline from config. @@ -283,13 +324,13 @@ def load( Inference pipeline """ with (Path(path) / "inference_config.yaml").open() as file: - inference_dict_config: dict[str, Any] = yaml.safe_load(file) + inference_nodes_configs: list[dict[str, Any]] = yaml.safe_load(file) inference_config = [ InferenceNodeConfig( **node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config ) - for node_config in inference_dict_config["nodes_configs"] + for node_config in inference_nodes_configs ] return cls.from_config(inference_config) diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py index 09fe1ed47..e63ae74ae 100644 --- a/autointent/configs/_inference_node.py +++ b/autointent/configs/_inference_node.py @@ -1,6 +1,6 @@ """Configuration for the nodes.""" -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any from autointent.custom_types import NodeType @@ -24,3 +24,15 @@ class InferenceNodeConfig: """One can override presaved embedder config while loading from file system.""" cross_encoder_config: CrossEncoderConfig | None = None """One can override presaved cross encoder config while loading from file system.""" + + def asdict(self) -> dict[str, Any]: + res = asdict(self) + if self.embedder_config is not None: + res["embedder_config"] = self.embedder_config.model_dump() + else: + res.pop("embedder_config") + if self.cross_encoder_config is not None: + res["cross_encoder_config"] = self.cross_encoder_config.model_dump() + else: + res.pop("cross_encoder_config") + return res diff --git a/autointent/context/_context.py b/autointent/context/_context.py index 9fb21f7a2..53b83859d 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -3,7 +3,6 @@ import json import logging from pathlib import Path -from typing import Any import yaml @@ -32,7 +31,7 @@ class Context: optimization_info: OptimizationInfo callback_handler = CallbackHandler() - def __init__(self, seed: int = 42) -> None: + def __init__(self, seed: int | None = 42) -> None: """Initialize the Context object. Args: @@ -71,22 +70,6 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None: """ self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config) - def get_inference_config(self) -> dict[str, Any]: - """Generate configuration settings for inference. - - Returns: - Dictionary containing inference configuration. - """ - nodes_configs = self.optimization_info.get_inference_nodes_config(asdict=True) - return { - "metadata": { - "multilabel": self.is_multilabel(), - "n_classes": self.get_n_classes(), - "seed": self.seed, - }, - "nodes_configs": nodes_configs, - } - def dump(self) -> None: """Save logs, configurations, and datasets to disk.""" self._logger.debug("dumping logs...") @@ -103,7 +86,7 @@ def dump(self) -> None: self._logger.info("logs and other assets are saved to %s", logs_dir) - inference_config = self.get_inference_config() + inference_config = self.optimization_info.get_inference_nodes_config(asdict=True) inference_config_path = logs_dir / "inference_config.yaml" with inference_config_path.open("w") as file: yaml.dump(inference_config, file) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 6a7291b5a..7e3b1f15a 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -2,7 +2,7 @@ import logging from collections.abc import Generator -from typing import TypedDict, cast +from typing import cast from datasets import concatenate_datasets from transformers import set_seed @@ -16,28 +16,14 @@ logger = logging.getLogger(__name__) -class RegexPatterns(TypedDict): - """Regex patterns for each intent class. - - Attributes: - id: Intent class id. - regex_full_match: Full match regex patterns. - regex_partial_match: Partial match regex patterns. - """ - - id: int - regex_full_match: list[str] - regex_partial_match: list[str] - - -class DataHandler: # TODO rename to Validator +class DataHandler: """Data handler class.""" def __init__( self, dataset: Dataset, config: DataConfig | None = None, - random_seed: int = 0, + random_seed: int | None = 0, ) -> None: """Initialize the data handler. @@ -46,7 +32,8 @@ def __init__( config: Configuration object random_seed: Seed for random number generation. """ - set_seed(random_seed) + if random_seed is not None: + set_seed(random_seed) self.random_seed = random_seed self.dataset = dataset @@ -59,15 +46,6 @@ def __init__( elif self.config.scheme == "cv": self._split_cv() - self.regex_patterns = [ - RegexPatterns( - id=intent.id, - regex_full_match=intent.regex_full_match, - regex_partial_match=intent.regex_partial_match, - ) - for intent in self.dataset.intents - ] - self.intent_descriptions = [intent.description for intent in self.dataset.intents] self.tags = self.dataset.get_tags() diff --git a/autointent/context/data_handler/_stratification.py b/autointent/context/data_handler/_stratification.py index a28d1469b..da3434cdb 100644 --- a/autointent/context/data_handler/_stratification.py +++ b/autointent/context/data_handler/_stratification.py @@ -29,7 +29,7 @@ def __init__( self, test_size: float, label_feature: str, - random_seed: int, + random_seed: int | None, shuffle: bool = True, ) -> None: """Initialize the StratifiedSplitter. @@ -283,7 +283,7 @@ def split_dataset( dataset: Dataset, split: str, test_size: float, - random_seed: int, + random_seed: int | None, allow_oos_in_train: bool | None = None, ) -> tuple[HFDataset, HFDataset]: """Split a Dataset object into training and testing subsets. diff --git a/autointent/metrics/regex.py b/autointent/metrics/regex.py index 29a8b5099..b92c9bd86 100644 --- a/autointent/metrics/regex.py +++ b/autointent/metrics/regex.py @@ -1,4 +1,4 @@ -"""Regex metrics for intent recognition.""" +"""Metrics for regex modules.""" from typing import Protocol diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py index 7cf4ae529..5eefed1e7 100644 --- a/autointent/modules/__init__.py +++ b/autointent/modules/__init__.py @@ -11,7 +11,7 @@ TunableDecision, ) from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding -from .regex import Regex +from .regex import SimpleRegex from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer T = TypeVar("T", bound=BaseModule) @@ -21,7 +21,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: return {module.name: module for module in modules} -REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex]) +REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([SimpleRegex]) EMBEDDING_MODULES: dict[str, type[BaseEmbedding]] = _create_modules_dict( [RetrievalAimedEmbedding, LogregAimedEmbedding] diff --git a/autointent/modules/decision/_tunable.py b/autointent/modules/decision/_tunable.py index 89ab24a11..8e0fff36a 100644 --- a/autointent/modules/decision/_tunable.py +++ b/autointent/modules/decision/_tunable.py @@ -85,7 +85,7 @@ def __init__( self, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320, - seed: int = 0, + seed: int | None = 0, tags: list[Tag] | None = None, ) -> None: """Initialize tunable predictor. @@ -222,7 +222,7 @@ def fit( self, probas: npt.NDArray[Any], labels: ListOfGenericLabels, - seed: int, + seed: int | None, tags: list[Tag] | None = None, ) -> None: """Fit the optimizer by finding optimal thresholds. diff --git a/autointent/modules/regex/__init__.py b/autointent/modules/regex/__init__.py index 895cb6924..58dee76f5 100644 --- a/autointent/modules/regex/__init__.py +++ b/autointent/modules/regex/__init__.py @@ -1,3 +1,3 @@ -from ._simple import Regex +from ._simple import SimpleRegex -__all__ = ["Regex"] +__all__ = ["SimpleRegex"] diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index 0ce441ca4..4945278ce 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -1,14 +1,16 @@ """Module for regular expressions based intent detection.""" +import json import re from collections.abc import Iterable +from pathlib import Path from typing import Any, TypedDict import numpy as np import numpy.typing as npt from autointent import Context -from autointent.context.data_handler._data_handler import RegexPatterns +from autointent.configs import CrossEncoderConfig, EmbedderConfig from autointent.context.optimization_info import Artifact from autointent.custom_types import LabelType, ListOfGenericLabels, ListOfLabels from autointent.metrics import REGEX_METRICS @@ -30,7 +32,7 @@ class RegexPatternsCompiled(TypedDict): regex_partial_match: list[re.Pattern[str]] -class Regex(BaseRegex): +class SimpleRegex(BaseRegex): """Regular expressions based intent detection module. A module that uses regular expressions to detect intents in text utterances. @@ -46,14 +48,14 @@ class Regex(BaseRegex): supports_oos = False @classmethod - def from_context(cls, context: Context) -> "Regex": + def from_context(cls, context: Context) -> "SimpleRegex": """Initialize from context. Args: context: Context object containing configuration Returns: - Initialized Regex instance + Initialized SimpleRegex instance """ return cls() @@ -63,15 +65,15 @@ def fit(self, intents: list[Intent]) -> None: Args: intents: List of intents to fit the model with """ - self.regex_patterns = [ - RegexPatterns( - id=intent.id, - regex_full_match=intent.regex_full_match, - regex_partial_match=intent.regex_partial_match, - ) + regex_patterns = [ + { + "id": intent.id, + "regex_full_match": intent.regex_full_match, + "regex_partial_match": intent.regex_partial_match, + } for intent in intents ] - self._compile_regex_patterns() + self._compile_regex_patterns(regex_patterns) def predict(self, utterances: list[str]) -> list[LabelType]: """Predict intents for given utterances. @@ -214,7 +216,7 @@ def score_metrics_cv( def clear_cache(self) -> None: """Clear cached regex patterns.""" - del self.regex_patterns + del self.regex_patterns_compiled def get_assets(self) -> Artifact: """Get model assets. @@ -224,7 +226,7 @@ def get_assets(self) -> Artifact: """ return Artifact() - def _compile_regex_patterns(self) -> None: + def _compile_regex_patterns(self, regex_patterns: list[dict[str, Any]]) -> None: """Compile regex patterns with case-insensitive flag.""" self.regex_patterns_compiled = [ RegexPatternsCompiled( @@ -236,5 +238,31 @@ def _compile_regex_patterns(self) -> None: re.compile(ptn, flags=re.IGNORECASE) for ptn in regex_patterns["regex_partial_match"] ], ) - for regex_patterns in self.regex_patterns + for regex_patterns in regex_patterns + ] + + def dump(self, path: str) -> None: + serialized = [ + { + "id": regex_patterns["id"], + "regex_full_match": [pattern.pattern for pattern in regex_patterns["regex_full_match"]], + "regex_partial_match": [pattern.pattern for pattern in regex_patterns["regex_partial_match"]], + } + for regex_patterns in self.regex_patterns_compiled ] + + dump_dir = Path(path) + dump_dir.mkdir(parents=True, exist_ok=True) + with (dump_dir / "regex_patterns.json").open("w") as file: + json.dump(serialized, file, indent=4, ensure_ascii=False) + + def load( + self, + path: str, + embedder_config: EmbedderConfig | None = None, + cross_encoder_config: CrossEncoderConfig | None = None, + ) -> None: + with (Path(path) / "regex_patterns.json").open() as file: + serialized: list[dict[str, Any]] = json.load(file) + + self._compile_regex_patterns(serialized) diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index eedaf5df5..689813cd3 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -25,6 +25,9 @@ m: [ 2, 3 ] cross_encoder_config: - cross-encoder/ms-marco-MiniLM-L-6-v2 + - module_name: sklearn + clf_name: [RandomForestClassifier] + n_estimators: [5, 10] - node_type: decision target_metric: decision_accuracy search_space: diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index 8b6cefc3a..241239b3c 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -21,6 +21,9 @@ m: [ 2, 3 ] cross_encoder_config: - model_name: cross-encoder/ms-marco-MiniLM-L-6-v2 + - module_name: sklearn + clf_name: [RandomForestClassifier] + n_estimators: [5, 10] - node_type: decision target_metric: decision_accuracy search_space: diff --git a/tests/modules/test_regex.py b/tests/modules/test_regex.py index c55befec2..a9fdb275b 100644 --- a/tests/modules/test_regex.py +++ b/tests/modules/test_regex.py @@ -1,6 +1,6 @@ import pytest -from autointent.modules import Regex +from autointent.modules import SimpleRegex from autointent.schemas import Intent @@ -14,7 +14,7 @@ def test_base_regex(partial_match, expected_predictions): Intent(id=1, name="account_blocked", regex_partial_match=[partial_match]), ] - matcher = Regex() + matcher = SimpleRegex() matcher.fit(train_data) test_data = [ diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 3e864fd03..542861f22 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( "task_type", - ["multiclass", "multilabel", "description"], + ["regex", "multiclass", "multilabel", "description"], ) def test_inference_from_config(dataset, task_type): project_dir = setup_environment() @@ -22,6 +22,7 @@ def test_inference_from_config(dataset, task_type): if task_type == "multilabel": dataset = dataset.to_multilabel() + # case 1: inference from file system context = pipeline_optimizer.fit(dataset) context.dump() @@ -30,13 +31,26 @@ def test_inference_from_config(dataset, task_type): prediction = inference_pipeline.predict(utterances) assert len(prediction) == 2 + # case 2: rich inference from file system rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) + if task_type == "regex": + assert rich_outputs.regex_predictions is not None + + # case 3: dump and then load pipeline + dump_dir = project_dir / "dumped_pipeline" + pipeline_optimizer.dump(dump_dir) + del pipeline_optimizer + + loaded_pipe = Pipeline.load(dump_dir) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 + @pytest.mark.parametrize( "task_type", - ["multiclass", "multilabel", "description"], + ["regex", "multiclass", "multilabel", "description"], ) def test_inference_on_the_fly(dataset, task_type): project_dir = setup_environment() @@ -44,21 +58,32 @@ def test_inference_on_the_fly(dataset, task_type): pipeline = Pipeline.from_search_space(search_space) - pipeline.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False)) + logging_config = LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False) + pipeline.set_config(logging_config) if task_type == "multilabel": dataset = dataset.to_multilabel() - context = pipeline.fit(dataset) + # case 1: simple inference on the fly + pipeline.fit(dataset) utterances = ["123", "hello world"] prediction = pipeline.predict(utterances) - assert len(prediction) == 2 + # case 2: rich inference on the fly rich_outputs = pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) - context.dump() + if task_type == "regex": + assert rich_outputs.regex_predictions is not None + + # case 3: dump and then load pipeline + pipeline.dump() + del pipeline + + loaded_pipe = Pipeline.load(logging_config.dirpath) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 def test_load_with_overrided_params(dataset): @@ -73,15 +98,37 @@ def test_load_with_overrided_params(dataset): context = pipeline_optimizer.fit(dataset) context.dump() + # case 1: simple inference from file system inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8)) utterances = ["123", "hello world"] prediction = inference_pipeline.predict(utterances) assert len(prediction) == 2 + # case 2: rich inference from file system rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) - assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8 + del inference_pipeline + + # case 3: dump and then load pipeline + pipeline_optimizer.dump() + del pipeline_optimizer + loaded_pipe = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8)) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 + assert loaded_pipe.nodes[NodeType.scoring].module._embedder.max_length == 8 + + +def test_no_saving(dataset): + project_dir = setup_environment() + search_space = get_search_space("light") + + pipeline_optimizer = Pipeline.from_search_space(search_space) + + logging_config = LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=True) + pipeline_optimizer.set_config(logging_config) -# TODO Pipeline.dump() + pipeline_optimizer.fit(dataset) + with pytest.raises(RuntimeError): + pipeline_optimizer.dump()