From 646611578a6ad6dc060a0202d6b284b556663f53 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 11:54:44 +0300 Subject: [PATCH 01/15] implement logic --- autointent/_pipeline/_pipeline.py | 73 ++++++++++++++++++++------- autointent/configs/_inference_node.py | 10 +++- autointent/context/_context.py | 19 +------ 3 files changed, 64 insertions(+), 38 deletions(-) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 0549ecb05..4631778b3 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -199,25 +199,42 @@ def fit( self._logger.warning( "Test data is not provided. Final test metrics won't be calculated after pipeline optimization." ) + elif context.logging_config.clear_ram and not context.logging_config.dump_modules: + self._logger.warning( + "Test data is provided, but final metrics won't be calculated " + "because fitted modules won't be saved neither in RAM nor in file system." + "Change settings in LoggerConfig to obtain different behavior." + ) if sampler is None: sampler = self.sampler self._fit(context, sampler) - if context.is_ram_to_clear(): + if context.logging_config.clear_ram and context.logging_config.dump_modules: nodes_configs = context.optimization_info.get_inference_nodes_config() nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs] - else: + elif not context.logging_config.clear_ram: modules_dict = context.optimization_info.get_best_modules() nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()] - - self.nodes = {node.node_type: node for node in nodes_list} + else: + self._logger.info( + "Skipping calculating final metrics because fitted modules weren't saved." + "Change settings in LoggerConfig to obtain different behavior." + ) + return context if refit_after: - # TODO reflect this refitting in dumped version of pipeline self._refit(context) + self._nodes_configs: dict[str, InferenceNodeConfig] = { + NodeType(cfg.node_type): cfg + for cfg in context.optimization_info.get_inference_nodes_config() + if cfg.node_type != NodeType.embedding + } + self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding} + self._dump_dir = context.logging_config.dump_dir + if test_utterances is not None: predictions = self.predict(test_utterances) for metric_name, metric in DECISION_METRICS.items(): @@ -229,6 +246,36 @@ def fit( return context + def dump(self, path: str | Path | None = None) -> None: + if isinstance(path, str): + path = Path(path) + elif path is None: + if hasattr(self, "_dump_dir"): + path = self._dump_dir + else: + msg = ( + "Either you didn't trained the pipeline yet or fitted modules weren't saved during optimization. " + "Change settings in LoggerConfig and retrain the pipeline to obtain different behavior." + ) + self._logger.error(msg) + raise RuntimeError(msg) + + # TODO add regex module handling + scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr] + decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr] + + scoring_dump_dir = str(path / "scoring_module") + decision_dump_dir = str(path / "decision_module") + scoring_module.dump(scoring_dump_dir) + decision_module.dump(decision_dump_dir) + + self._nodes_configs[NodeType.scoring].load_path = scoring_dump_dir + self._nodes_configs[NodeType.decision].load_path = decision_dump_dir + + inference_nodes_configs = [cfg.asdict() for cfg in self._nodes_configs.values()] + with (path / "inference_config.yaml").open("w") as file: + yaml.dump(inference_nodes_configs, file) + def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None: """Validate modules with dataset. @@ -240,18 +287,6 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> if isinstance(node, NodeOptimizer): node.validate_nodes_with_dataset(dataset, mode) - @classmethod - def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline": - """Create inference pipeline from dictionary config. - - Args: - nodes_configs: list of config for nodes - - Returns: - Inference pipeline - """ - return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs]) - @classmethod def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline": """Create inference pipeline from config. @@ -283,13 +318,13 @@ def load( Inference pipeline """ with (Path(path) / "inference_config.yaml").open() as file: - inference_dict_config: dict[str, Any] = yaml.safe_load(file) + inference_nodes_configs: list[dict[str, Any]] = yaml.safe_load(file) inference_config = [ InferenceNodeConfig( **node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config ) - for node_config in inference_dict_config["nodes_configs"] + for node_config in inference_nodes_configs ] return cls.from_config(inference_config) diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py index 09fe1ed47..ef5d93902 100644 --- a/autointent/configs/_inference_node.py +++ b/autointent/configs/_inference_node.py @@ -1,6 +1,6 @@ """Configuration for the nodes.""" -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any from autointent.custom_types import NodeType @@ -24,3 +24,11 @@ class InferenceNodeConfig: """One can override presaved embedder config while loading from file system.""" cross_encoder_config: CrossEncoderConfig | None = None """One can override presaved cross encoder config while loading from file system.""" + + def asdict(self) -> dict[str, Any]: + res = asdict(self) + if self.embedder_config is not None: + res["embedder_config"] = self.embedder_config.model_dump() + if self.cross_encoder_config is not None: + res["cross_encoder_config"] = self.cross_encoder_config.model_dump() + return res diff --git a/autointent/context/_context.py b/autointent/context/_context.py index 9fb21f7a2..09ec109b3 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -3,7 +3,6 @@ import json import logging from pathlib import Path -from typing import Any import yaml @@ -71,22 +70,6 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None: """ self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config) - def get_inference_config(self) -> dict[str, Any]: - """Generate configuration settings for inference. - - Returns: - Dictionary containing inference configuration. - """ - nodes_configs = self.optimization_info.get_inference_nodes_config(asdict=True) - return { - "metadata": { - "multilabel": self.is_multilabel(), - "n_classes": self.get_n_classes(), - "seed": self.seed, - }, - "nodes_configs": nodes_configs, - } - def dump(self) -> None: """Save logs, configurations, and datasets to disk.""" self._logger.debug("dumping logs...") @@ -103,7 +86,7 @@ def dump(self) -> None: self._logger.info("logs and other assets are saved to %s", logs_dir) - inference_config = self.get_inference_config() + inference_config = self.optimization_info.get_inference_nodes_config(asdict=True) inference_config_path = logs_dir / "inference_config.yaml" with inference_config_path.open("w") as file: yaml.dump(inference_config, file) From c220f1c22a80a04889c9642f3af5dadee2aa2344 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:27:40 +0300 Subject: [PATCH 02/15] refactor tests a little bit --- tests/pipeline/test_inference.py | 53 ++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 3e864fd03..3c5f5c000 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -22,6 +22,8 @@ def test_inference_from_config(dataset, task_type): if task_type == "multilabel": dataset = dataset.to_multilabel() + + # case 1: inference from file system context = pipeline_optimizer.fit(dataset) context.dump() @@ -30,9 +32,18 @@ def test_inference_from_config(dataset, task_type): prediction = inference_pipeline.predict(utterances) assert len(prediction) == 2 + # case 2: rich inference from file system rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) + # case 3: dump and then load pipeline + dump_dir = project_dir / "dumped_pipeline" + pipeline_optimizer.dump(dump_dir) + del pipeline_optimizer + + loaded_pipe = Pipeline.load(dump_dir) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 @pytest.mark.parametrize( "task_type", @@ -44,21 +55,29 @@ def test_inference_on_the_fly(dataset, task_type): pipeline = Pipeline.from_search_space(search_space) - pipeline.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False)) + logging_config = LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False) + pipeline.set_config(logging_config) if task_type == "multilabel": dataset = dataset.to_multilabel() - context = pipeline.fit(dataset) + # case 1: simple inference on the fly + pipeline.fit(dataset) utterances = ["123", "hello world"] prediction = pipeline.predict(utterances) - assert len(prediction) == 2 + # case 2: rich inference on the fly rich_outputs = pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) - context.dump() + # case 3: dump and then load pipeline + pipeline.dump() + del pipeline + + loaded_pipe = Pipeline.load(logging_config.dirpath) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 def test_load_with_overrided_params(dataset): @@ -73,15 +92,37 @@ def test_load_with_overrided_params(dataset): context = pipeline_optimizer.fit(dataset) context.dump() + # case 1: simple inference from file system inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8)) utterances = ["123", "hello world"] prediction = inference_pipeline.predict(utterances) assert len(prediction) == 2 + # case 2: rich inference from file system rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) - assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8 + del inference_pipeline + + # case 3: dump and then load pipeline + pipeline_optimizer.dump() + del pipeline_optimizer + + loaded_pipe = Pipeline.load(logging_config.dirpath) + prediction_v2 = loaded_pipe.predict(utterances) + assert prediction == prediction_v2 + assert loaded_pipe.nodes[NodeType.scoring].module._embedder.max_length == 8 -# TODO Pipeline.dump() +def test_no_saving(dataset): + project_dir = setup_environment() + search_space = get_search_space("light") + + pipeline_optimizer = Pipeline.from_search_space(search_space) + + logging_config = LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=True) + pipeline_optimizer.set_config(logging_config) + + pipeline_optimizer.fit(dataset) + with pytest.raises(RuntimeError): + pipeline_optimizer.dump() From 49b817bcb4b1052b34c53cc080905b9f4d93f374 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:28:03 +0300 Subject: [PATCH 03/15] minor bug fix --- autointent/configs/_inference_node.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py index ef5d93902..e63ae74ae 100644 --- a/autointent/configs/_inference_node.py +++ b/autointent/configs/_inference_node.py @@ -29,6 +29,10 @@ def asdict(self) -> dict[str, Any]: res = asdict(self) if self.embedder_config is not None: res["embedder_config"] = self.embedder_config.model_dump() + else: + res.pop("embedder_config") if self.cross_encoder_config is not None: res["cross_encoder_config"] = self.cross_encoder_config.model_dump() + else: + res.pop("cross_encoder_config") return res From e1c0aab81e14df6fd2ee60b02f6abd6884c3b82c Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:28:31 +0300 Subject: [PATCH 04/15] add `None` option for random seed --- autointent/_pipeline/_pipeline.py | 8 ++++---- autointent/context/_context.py | 2 +- autointent/context/data_handler/_data_handler.py | 7 ++++--- autointent/context/data_handler/_stratification.py | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 4631778b3..8b6d19205 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -41,7 +41,7 @@ def __init__( self, nodes: list[NodeOptimizer] | list[InferenceNode], sampler: SamplerType = "brute", - seed: int = 42, + seed: int | None = 42, ) -> None: """Initialize the pipeline optimizer. @@ -85,7 +85,7 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig assert_never(config) @classmethod - def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline": + def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int | None = 42) -> "Pipeline": """Search space to pipeline optimizer. Args: @@ -101,7 +101,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed return cls(nodes=nodes, seed=seed) @classmethod - def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline": + def from_preset(cls, name: SearchSpacePresets, seed: int | None = 42) -> "Pipeline": optimization_config = load_preset(name) config = OptimizationConfig(seed=seed, **optimization_config) return cls.from_optimization_config(config=config) @@ -186,7 +186,7 @@ def fit( msg = "Pipeline in inference mode cannot be fitted" raise RuntimeError(msg) - context = Context() + context = Context(self.seed) context.set_dataset(dataset, self.data_config) context.configure_logging(self.logging_config) context.configure_transformer(self.embedder_config) diff --git a/autointent/context/_context.py b/autointent/context/_context.py index 09ec109b3..53b83859d 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -31,7 +31,7 @@ class Context: optimization_info: OptimizationInfo callback_handler = CallbackHandler() - def __init__(self, seed: int = 42) -> None: + def __init__(self, seed: int | None = 42) -> None: """Initialize the Context object. Args: diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 6a7291b5a..910243a42 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -30,14 +30,14 @@ class RegexPatterns(TypedDict): regex_partial_match: list[str] -class DataHandler: # TODO rename to Validator +class DataHandler: """Data handler class.""" def __init__( self, dataset: Dataset, config: DataConfig | None = None, - random_seed: int = 0, + random_seed: int | None = 0, ) -> None: """Initialize the data handler. @@ -46,7 +46,8 @@ def __init__( config: Configuration object random_seed: Seed for random number generation. """ - set_seed(random_seed) + if random_seed is not None: + set_seed(random_seed) self.random_seed = random_seed self.dataset = dataset diff --git a/autointent/context/data_handler/_stratification.py b/autointent/context/data_handler/_stratification.py index a28d1469b..da3434cdb 100644 --- a/autointent/context/data_handler/_stratification.py +++ b/autointent/context/data_handler/_stratification.py @@ -29,7 +29,7 @@ def __init__( self, test_size: float, label_feature: str, - random_seed: int, + random_seed: int | None, shuffle: bool = True, ) -> None: """Initialize the StratifiedSplitter. @@ -283,7 +283,7 @@ def split_dataset( dataset: Dataset, split: str, test_size: float, - random_seed: int, + random_seed: int | None, allow_oos_in_train: bool | None = None, ) -> tuple[HFDataset, HFDataset]: """Split a Dataset object into training and testing subsets. From 7fae6184f9e458e05e9fa41f1e237fbfb21d0848 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:29:38 +0300 Subject: [PATCH 05/15] fix typing and codestyle --- autointent/modules/decision/_tunable.py | 4 ++-- tests/pipeline/test_inference.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autointent/modules/decision/_tunable.py b/autointent/modules/decision/_tunable.py index 89ab24a11..8e0fff36a 100644 --- a/autointent/modules/decision/_tunable.py +++ b/autointent/modules/decision/_tunable.py @@ -85,7 +85,7 @@ def __init__( self, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320, - seed: int = 0, + seed: int | None = 0, tags: list[Tag] | None = None, ) -> None: """Initialize tunable predictor. @@ -222,7 +222,7 @@ def fit( self, probas: npt.NDArray[Any], labels: ListOfGenericLabels, - seed: int, + seed: int | None, tags: list[Tag] | None = None, ) -> None: """Fit the optimizer by finding optimal thresholds. diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 3c5f5c000..40789e240 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -22,7 +22,6 @@ def test_inference_from_config(dataset, task_type): if task_type == "multilabel": dataset = dataset.to_multilabel() - # case 1: inference from file system context = pipeline_optimizer.fit(dataset) context.dump() @@ -45,6 +44,7 @@ def test_inference_from_config(dataset, task_type): prediction_v2 = loaded_pipe.predict(utterances) assert prediction == prediction_v2 + @pytest.mark.parametrize( "task_type", ["multiclass", "multilabel", "description"], From bb358e83586ac1ef38106a2fd0eec130c6d9e11b Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:46:14 +0300 Subject: [PATCH 06/15] minor bug fix --- autointent/_pipeline/_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 8b6d19205..bab63ec9a 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -233,7 +233,7 @@ def fit( if cfg.node_type != NodeType.embedding } self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding} - self._dump_dir = context.logging_config.dump_dir + self._dump_dir = context.logging_config.dirpath if test_utterances is not None: predictions = self.predict(test_utterances) From 7cbe5eb70285b2ee87b64b50c30aac8d7e11dd39 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:46:20 +0300 Subject: [PATCH 07/15] add sklearn to tests --- tests/assets/configs/multiclass.yaml | 3 +++ tests/assets/configs/multilabel.yaml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index eedaf5df5..689813cd3 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -25,6 +25,9 @@ m: [ 2, 3 ] cross_encoder_config: - cross-encoder/ms-marco-MiniLM-L-6-v2 + - module_name: sklearn + clf_name: [RandomForestClassifier] + n_estimators: [5, 10] - node_type: decision target_metric: decision_accuracy search_space: diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index 8b6cefc3a..241239b3c 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -21,6 +21,9 @@ m: [ 2, 3 ] cross_encoder_config: - model_name: cross-encoder/ms-marco-MiniLM-L-6-v2 + - module_name: sklearn + clf_name: [RandomForestClassifier] + n_estimators: [5, 10] - node_type: decision target_metric: decision_accuracy search_space: From 52ba6d2cef2ddcab4234abc76a99b4a6bd906643 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:51:50 +0300 Subject: [PATCH 08/15] fix test --- tests/pipeline/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 40789e240..8cf92885f 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -108,7 +108,7 @@ def test_load_with_overrided_params(dataset): pipeline_optimizer.dump() del pipeline_optimizer - loaded_pipe = Pipeline.load(logging_config.dirpath) + loaded_pipe = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8)) prediction_v2 = loaded_pipe.predict(utterances) assert prediction == prediction_v2 assert loaded_pipe.nodes[NodeType.scoring].module._embedder.max_length == 8 From cbb74515f9c59dc318a1c577d3ec6bef047a222f Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 12:54:51 +0300 Subject: [PATCH 09/15] `Regex` -> `SimpleRegex` --- autointent/context/data_handler/_data_handler.py | 2 +- autointent/metrics/regex.py | 2 +- autointent/modules/__init__.py | 4 ++-- autointent/modules/regex/__init__.py | 4 ++-- autointent/modules/regex/_simple.py | 6 +++--- autointent/nodes/info/_regex.py | 4 ++-- tests/modules/test_regex.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 910243a42..f51721c32 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -17,7 +17,7 @@ class RegexPatterns(TypedDict): - """Regex patterns for each intent class. + """SimpleRegex patterns for each intent class. Attributes: id: Intent class id. diff --git a/autointent/metrics/regex.py b/autointent/metrics/regex.py index 29a8b5099..d5e29e2e9 100644 --- a/autointent/metrics/regex.py +++ b/autointent/metrics/regex.py @@ -1,4 +1,4 @@ -"""Regex metrics for intent recognition.""" +"""SimpleRegex metrics for intent recognition.""" from typing import Protocol diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py index 7cf4ae529..5eefed1e7 100644 --- a/autointent/modules/__init__.py +++ b/autointent/modules/__init__.py @@ -11,7 +11,7 @@ TunableDecision, ) from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding -from .regex import Regex +from .regex import SimpleRegex from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer T = TypeVar("T", bound=BaseModule) @@ -21,7 +21,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: return {module.name: module for module in modules} -REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex]) +REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([SimpleRegex]) EMBEDDING_MODULES: dict[str, type[BaseEmbedding]] = _create_modules_dict( [RetrievalAimedEmbedding, LogregAimedEmbedding] diff --git a/autointent/modules/regex/__init__.py b/autointent/modules/regex/__init__.py index 895cb6924..58dee76f5 100644 --- a/autointent/modules/regex/__init__.py +++ b/autointent/modules/regex/__init__.py @@ -1,3 +1,3 @@ -from ._simple import Regex +from ._simple import SimpleRegex -__all__ = ["Regex"] +__all__ = ["SimpleRegex"] diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index 0ce441ca4..09a80cf81 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -30,7 +30,7 @@ class RegexPatternsCompiled(TypedDict): regex_partial_match: list[re.Pattern[str]] -class Regex(BaseRegex): +class SimpleRegex(BaseRegex): """Regular expressions based intent detection module. A module that uses regular expressions to detect intents in text utterances. @@ -46,14 +46,14 @@ class Regex(BaseRegex): supports_oos = False @classmethod - def from_context(cls, context: Context) -> "Regex": + def from_context(cls, context: Context) -> "SimpleRegex": """Initialize from context. Args: context: Context object containing configuration Returns: - Initialized Regex instance + Initialized SimpleRegex instance """ return cls() diff --git a/autointent/nodes/info/_regex.py b/autointent/nodes/info/_regex.py index d01603100..dce222230 100644 --- a/autointent/nodes/info/_regex.py +++ b/autointent/nodes/info/_regex.py @@ -1,4 +1,4 @@ -"""Regex node info.""" +"""SimpleRegex node info.""" from collections.abc import Mapping from typing import ClassVar @@ -13,7 +13,7 @@ class RegexNodeInfo(NodeInfo): - """Regex node info.""" + """SimpleRegex node info.""" metrics_available: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS diff --git a/tests/modules/test_regex.py b/tests/modules/test_regex.py index c55befec2..a9fdb275b 100644 --- a/tests/modules/test_regex.py +++ b/tests/modules/test_regex.py @@ -1,6 +1,6 @@ import pytest -from autointent.modules import Regex +from autointent.modules import SimpleRegex from autointent.schemas import Intent @@ -14,7 +14,7 @@ def test_base_regex(partial_match, expected_predictions): Intent(id=1, name="account_blocked", regex_partial_match=[partial_match]), ] - matcher = Regex() + matcher = SimpleRegex() matcher.fit(train_data) test_data = [ From ad3cc1bffb21043f723dba739792b258e9463174 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 13:11:13 +0300 Subject: [PATCH 10/15] implement regex module loading and dumping --- autointent/_pipeline/_pipeline.py | 9 +++++-- autointent/modules/regex/_simple.py | 37 +++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index bab63ec9a..4e0a3343d 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -31,7 +31,7 @@ from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput if TYPE_CHECKING: - from autointent.modules.base import BaseDecision, BaseScorer + from autointent.modules.base import BaseDecision, BaseRegex, BaseScorer class Pipeline: @@ -260,7 +260,6 @@ def dump(self, path: str | Path | None = None) -> None: self._logger.error(msg) raise RuntimeError(msg) - # TODO add regex module handling scoring_module: BaseScorer = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr] decision_module: BaseDecision = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr] @@ -272,6 +271,12 @@ def dump(self, path: str | Path | None = None) -> None: self._nodes_configs[NodeType.scoring].load_path = scoring_dump_dir self._nodes_configs[NodeType.decision].load_path = decision_dump_dir + if NodeType.regex in self.nodes: + regex_module: BaseRegex = self.nodes[NodeType.regex].module # type: ignore[assignment,union-attr] + regex_dump_dir = str(path / "regex_module") + regex_module.dump(regex_dump_dir) + self._nodes_configs[NodeType.regex].load_path = regex_dump_dir + inference_nodes_configs = [cfg.asdict() for cfg in self._nodes_configs.values()] with (path / "inference_config.yaml").open("w") as file: yaml.dump(inference_nodes_configs, file) diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index 09a80cf81..37a15474e 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -1,13 +1,16 @@ """Module for regular expressions based intent detection.""" +import json import re from collections.abc import Iterable +from pathlib import Path from typing import Any, TypedDict import numpy as np import numpy.typing as npt from autointent import Context +from autointent.configs import CrossEncoderConfig, EmbedderConfig from autointent.context.data_handler._data_handler import RegexPatterns from autointent.context.optimization_info import Artifact from autointent.custom_types import LabelType, ListOfGenericLabels, ListOfLabels @@ -63,7 +66,7 @@ def fit(self, intents: list[Intent]) -> None: Args: intents: List of intents to fit the model with """ - self.regex_patterns = [ + regex_patterns = [ RegexPatterns( id=intent.id, regex_full_match=intent.regex_full_match, @@ -71,7 +74,7 @@ def fit(self, intents: list[Intent]) -> None: ) for intent in intents ] - self._compile_regex_patterns() + self._compile_regex_patterns(regex_patterns) def predict(self, utterances: list[str]) -> list[LabelType]: """Predict intents for given utterances. @@ -214,7 +217,7 @@ def score_metrics_cv( def clear_cache(self) -> None: """Clear cached regex patterns.""" - del self.regex_patterns + del self.regex_patterns_compiled def get_assets(self) -> Artifact: """Get model assets. @@ -224,7 +227,7 @@ def get_assets(self) -> Artifact: """ return Artifact() - def _compile_regex_patterns(self) -> None: + def _compile_regex_patterns(self, regex_patterns: list[dict[str, Any]]) -> None: """Compile regex patterns with case-insensitive flag.""" self.regex_patterns_compiled = [ RegexPatternsCompiled( @@ -236,5 +239,29 @@ def _compile_regex_patterns(self) -> None: re.compile(ptn, flags=re.IGNORECASE) for ptn in regex_patterns["regex_partial_match"] ], ) - for regex_patterns in self.regex_patterns + for regex_patterns in regex_patterns ] + + def dump(self, path: str) -> None: + serialized = [ + { + "id": regex_patterns["id"], + "regex_full_match": [pattern.pattern for pattern in regex_patterns["regex_full_match"]], + "regex_partial_match": [pattern.pattern for pattern in regex_patterns["regex_partial_match"]], + } + for regex_patterns in self.regex_patterns_compiled + ] + + with (Path(path) / "regex_patterns").open("w") as file: + json.dump(serialized, file, indent=4, ensure_ascii=False) + + def load( + self, + path: str, + embedder_config: EmbedderConfig | None = None, + cross_encoder_config: CrossEncoderConfig | None = None, + ) -> None: + with (Path(path) / "regex_patterns").open() as file: + serialized: list[dict[str, Any]] = json.load(file) + + self._compile_regex_patterns(serialized) From ed59b4950f64ce4baa1779123507048cce429261 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 13:13:41 +0300 Subject: [PATCH 11/15] fix typing and codestyle --- autointent/context/data_handler/_data_handler.py | 9 --------- autointent/modules/regex/_simple.py | 11 +++++------ 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index f51721c32..a32ee67c7 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -60,15 +60,6 @@ def __init__( elif self.config.scheme == "cv": self._split_cv() - self.regex_patterns = [ - RegexPatterns( - id=intent.id, - regex_full_match=intent.regex_full_match, - regex_partial_match=intent.regex_partial_match, - ) - for intent in self.dataset.intents - ] - self.intent_descriptions = [intent.description for intent in self.dataset.intents] self.tags = self.dataset.get_tags() diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index 37a15474e..132934476 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -11,7 +11,6 @@ from autointent import Context from autointent.configs import CrossEncoderConfig, EmbedderConfig -from autointent.context.data_handler._data_handler import RegexPatterns from autointent.context.optimization_info import Artifact from autointent.custom_types import LabelType, ListOfGenericLabels, ListOfLabels from autointent.metrics import REGEX_METRICS @@ -67,11 +66,11 @@ def fit(self, intents: list[Intent]) -> None: intents: List of intents to fit the model with """ regex_patterns = [ - RegexPatterns( - id=intent.id, - regex_full_match=intent.regex_full_match, - regex_partial_match=intent.regex_partial_match, - ) + { + "id": intent.id, + "regex_full_match": intent.regex_full_match, + "regex_partial_match": intent.regex_partial_match, + } for intent in intents ] self._compile_regex_patterns(regex_patterns) From 4fd7e5d4a627a5c4d69e7db1bed1606e9678e730 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 13:28:09 +0300 Subject: [PATCH 12/15] minor bug fix --- autointent/modules/regex/_simple.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index 132934476..4945278ce 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -251,7 +251,9 @@ def dump(self, path: str) -> None: for regex_patterns in self.regex_patterns_compiled ] - with (Path(path) / "regex_patterns").open("w") as file: + dump_dir = Path(path) + dump_dir.mkdir(parents=True, exist_ok=True) + with (dump_dir / "regex_patterns.json").open("w") as file: json.dump(serialized, file, indent=4, ensure_ascii=False) def load( @@ -260,7 +262,7 @@ def load( embedder_config: EmbedderConfig | None = None, cross_encoder_config: CrossEncoderConfig | None = None, ) -> None: - with (Path(path) / "regex_patterns").open() as file: + with (Path(path) / "regex_patterns.json").open() as file: serialized: list[dict[str, Any]] = json.load(file) self._compile_regex_patterns(serialized) From 515f551314f952d0f1fc2bd0e74104fafcf04f07 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 13:28:21 +0300 Subject: [PATCH 13/15] add tests for regex --- tests/pipeline/test_inference.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 8cf92885f..542861f22 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( "task_type", - ["multiclass", "multilabel", "description"], + ["regex", "multiclass", "multilabel", "description"], ) def test_inference_from_config(dataset, task_type): project_dir = setup_environment() @@ -35,6 +35,9 @@ def test_inference_from_config(dataset, task_type): rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) + if task_type == "regex": + assert rich_outputs.regex_predictions is not None + # case 3: dump and then load pipeline dump_dir = project_dir / "dumped_pipeline" pipeline_optimizer.dump(dump_dir) @@ -47,7 +50,7 @@ def test_inference_from_config(dataset, task_type): @pytest.mark.parametrize( "task_type", - ["multiclass", "multilabel", "description"], + ["regex", "multiclass", "multilabel", "description"], ) def test_inference_on_the_fly(dataset, task_type): project_dir = setup_environment() @@ -71,6 +74,9 @@ def test_inference_on_the_fly(dataset, task_type): rich_outputs = pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) + if task_type == "regex": + assert rich_outputs.regex_predictions is not None + # case 3: dump and then load pipeline pipeline.dump() del pipeline From 5eaec2b2ab1a0d317d3abdfac9cf404a294b5356 Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 14:12:30 +0300 Subject: [PATCH 14/15] minor changes --- autointent/context/data_handler/_data_handler.py | 16 +--------------- autointent/metrics/regex.py | 2 +- autointent/nodes/info/_regex.py | 4 ++-- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index a32ee67c7..7e3b1f15a 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -2,7 +2,7 @@ import logging from collections.abc import Generator -from typing import TypedDict, cast +from typing import cast from datasets import concatenate_datasets from transformers import set_seed @@ -16,20 +16,6 @@ logger = logging.getLogger(__name__) -class RegexPatterns(TypedDict): - """SimpleRegex patterns for each intent class. - - Attributes: - id: Intent class id. - regex_full_match: Full match regex patterns. - regex_partial_match: Partial match regex patterns. - """ - - id: int - regex_full_match: list[str] - regex_partial_match: list[str] - - class DataHandler: """Data handler class.""" diff --git a/autointent/metrics/regex.py b/autointent/metrics/regex.py index d5e29e2e9..b92c9bd86 100644 --- a/autointent/metrics/regex.py +++ b/autointent/metrics/regex.py @@ -1,4 +1,4 @@ -"""SimpleRegex metrics for intent recognition.""" +"""Metrics for regex modules.""" from typing import Protocol diff --git a/autointent/nodes/info/_regex.py b/autointent/nodes/info/_regex.py index dce222230..d01603100 100644 --- a/autointent/nodes/info/_regex.py +++ b/autointent/nodes/info/_regex.py @@ -1,4 +1,4 @@ -"""SimpleRegex node info.""" +"""Regex node info.""" from collections.abc import Mapping from typing import ClassVar @@ -13,7 +13,7 @@ class RegexNodeInfo(NodeInfo): - """SimpleRegex node info.""" + """Regex node info.""" metrics_available: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS From d31ecb3783dab6ec247412075887463293c2169a Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 3 Mar 2025 14:30:06 +0300 Subject: [PATCH 15/15] fix refitting logic error --- autointent/_pipeline/_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 4e0a3343d..0e9712395 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -224,6 +224,8 @@ def fit( ) return context + self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding} + if refit_after: self._refit(context) @@ -232,7 +234,6 @@ def fit( for cfg in context.optimization_info.get_inference_nodes_config() if cfg.node_type != NodeType.embedding } - self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding} self._dump_dir = context.logging_config.dirpath if test_utterances is not None: