
Commit de4d37d

Refactor/pipeline load and dump (#157)
* implement logic
* refactor tests a little bit
* minor bug fix
* add `None` option for random seed
* fix typing and codestyle
* minor bug fix
* add sklearn to tests
* fix test
* `Regex` -> `SimpleRegex`
* implement regex module loading and dumping
* fix typing and codestyle
* minor bug fix
* add tests for regex
* minor changes
* fix refitting logic error
1 parent e7ceb52 commit de4d37d

File tree

14 files changed (+198, -103 lines)

autointent/_pipeline/_pipeline.py

Lines changed: 64 additions & 23 deletions

@@ -31,7 +31,7 @@
 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
 
 if TYPE_CHECKING:
-    from autointent.modules.base import BaseDecision, BaseScorer
+    from autointent.modules.base import BaseDecision, BaseRegex, BaseScorer
 
 
 class Pipeline:
@@ -41,7 +41,7 @@ def __init__(
         self,
         nodes: list[NodeOptimizer] | list[InferenceNode],
         sampler: SamplerType = "brute",
-        seed: int = 42,
+        seed: int | None = 42,
     ) -> None:
         """Initialize the pipeline optimizer.
 
@@ -85,7 +85,7 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
         assert_never(config)
 
     @classmethod
-    def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
+    def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int | None = 42) -> "Pipeline":
         """Search space to pipeline optimizer.
 
         Args:
@@ -101,7 +101,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
         return cls(nodes=nodes, seed=seed)
 
     @classmethod
-    def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline":
+    def from_preset(cls, name: SearchSpacePresets, seed: int | None = 42) -> "Pipeline":
        optimization_config = load_preset(name)
        config = OptimizationConfig(seed=seed, **optimization_config)
        return cls.from_optimization_config(config=config)
@@ -186,7 +186,7 @@ def fit(
             msg = "Pipeline in inference mode cannot be fitted"
             raise RuntimeError(msg)
 
-        context = Context()
+        context = Context(self.seed)
         context.set_dataset(dataset, self.data_config)
         context.configure_logging(self.logging_config)
         context.configure_transformer(self.embedder_config)
@@ -199,25 +199,43 @@
             self._logger.warning(
                 "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
             )
+        elif context.logging_config.clear_ram and not context.logging_config.dump_modules:
+            self._logger.warning(
+                "Test data is provided, but final metrics won't be calculated "
+                "because fitted modules won't be saved either in RAM or in the file system. "
+                "Change settings in LoggingConfig to obtain different behavior."
+            )
 
         if sampler is None:
             sampler = self.sampler
 
         self._fit(context, sampler)
 
-        if context.is_ram_to_clear():
+        if context.logging_config.clear_ram and context.logging_config.dump_modules:
             nodes_configs = context.optimization_info.get_inference_nodes_config()
             nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
-        else:
+        elif not context.logging_config.clear_ram:
             modules_dict = context.optimization_info.get_best_modules()
             nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
+        else:
+            self._logger.info(
+                "Skipping final metrics calculation because fitted modules weren't saved. "
+                "Change settings in LoggingConfig to obtain different behavior."
+            )
+            return context
 
-        self.nodes = {node.node_type: node for node in nodes_list}
+        self.nodes = {node.node_type: node for node in nodes_list if node.node_type != NodeType.embedding}
 
         if refit_after:
-            # TODO reflect this refitting in dumped version of pipeline
             self._refit(context)
 
+        self._nodes_configs: dict[str, InferenceNodeConfig] = {
+            NodeType(cfg.node_type): cfg
+            for cfg in context.optimization_info.get_inference_nodes_config()
+            if cfg.node_type != NodeType.embedding
+        }
+        self._dump_dir = context.logging_config.dirpath
+
         if test_utterances is not None:
             predictions = self.predict(test_utterances)
             for metric_name, metric in DECISION_METRICS.items():
@@ -229,6 +247,41 @@ def fit(
 
         return context
 
+    def dump(self, path: str | Path | None = None) -> None:
+        if isinstance(path, str):
+            path = Path(path)
+        elif path is None:
+            if hasattr(self, "_dump_dir"):
+                path = self._dump_dir
+            else:
+                msg = (
+                    "Either you haven't trained the pipeline yet or fitted modules weren't saved during optimization. "
+                    "Change settings in LoggingConfig and retrain the pipeline to obtain different behavior."
+                )
+                self._logger.error(msg)
+                raise RuntimeError(msg)
+
+        scoring_module: BaseScorer = self.nodes[NodeType.scoring].module  # type: ignore[assignment,union-attr]
+        decision_module: BaseDecision = self.nodes[NodeType.decision].module  # type: ignore[assignment,union-attr]
+
+        scoring_dump_dir = str(path / "scoring_module")
+        decision_dump_dir = str(path / "decision_module")
+        scoring_module.dump(scoring_dump_dir)
+        decision_module.dump(decision_dump_dir)
+
+        self._nodes_configs[NodeType.scoring].load_path = scoring_dump_dir
+        self._nodes_configs[NodeType.decision].load_path = decision_dump_dir
+
+        if NodeType.regex in self.nodes:
+            regex_module: BaseRegex = self.nodes[NodeType.regex].module  # type: ignore[assignment,union-attr]
+            regex_dump_dir = str(path / "regex_module")
+            regex_module.dump(regex_dump_dir)
+            self._nodes_configs[NodeType.regex].load_path = regex_dump_dir
+
+        inference_nodes_configs = [cfg.asdict() for cfg in self._nodes_configs.values()]
+        with (path / "inference_config.yaml").open("w") as file:
+            yaml.dump(inference_nodes_configs, file)
+
     def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None:
         """Validate modules with dataset.
 
@@ -240,18 +293,6 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) ->
            if isinstance(node, NodeOptimizer):
                node.validate_nodes_with_dataset(dataset, mode)
 
-    @classmethod
-    def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
-        """Create inference pipeline from dictionary config.
-
-        Args:
-            nodes_configs: list of config for nodes
-
-        Returns:
-            Inference pipeline
-        """
-        return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs])
-
     @classmethod
     def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
         """Create inference pipeline from config.
@@ -283,13 +324,13 @@ def load(
             Inference pipeline
         """
         with (Path(path) / "inference_config.yaml").open() as file:
-            inference_dict_config: dict[str, Any] = yaml.safe_load(file)
+            inference_nodes_configs: list[dict[str, Any]] = yaml.safe_load(file)
 
         inference_config = [
             InferenceNodeConfig(
                 **node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config
            )
-            for node_config in inference_dict_config["nodes_configs"]
+            for node_config in inference_nodes_configs
         ]
 
         return cls.from_config(inference_config)
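
Taken together, these changes give the pipeline a self-contained optimize, dump, and load round trip. A minimal sketch of that flow under stated assumptions: the top-level `autointent` exports, the `Dataset.from_json` loader, the search-space file, and the output directory are all hypothetical here; only `from_search_space`, `fit`, `dump`, `load`, and `predict` are confirmed by this diff:

    from autointent import Dataset, Pipeline  # top-level exports assumed

    # seed may now be None, which skips deterministic seeding entirely
    pipeline = Pipeline.from_search_space("search_space.yaml", seed=None)

    dataset = Dataset.from_json("dataset.json")  # hypothetical loader and file
    context = pipeline.fit(dataset)

    # writes scoring_module/ and decision_module/ (plus regex_module/ if a regex
    # node is present) and a flat inference_config.yaml into the directory
    pipeline.dump("my_pipeline")

    # restore an inference-only pipeline from the dumped artifacts
    loaded = Pipeline.load("my_pipeline")
    predictions = loaded.predict(["i want to check my balance"])

When `dump` is called with no argument, it falls back to the `_dump_dir` recorded during `fit`, so a pipeline optimized with `dump_modules` enabled can be persisted without choosing a path.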

autointent/configs/_inference_node.py

Lines changed: 13 additions & 1 deletion

@@ -1,6 +1,6 @@
 """Configuration for the nodes."""
 
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Any
 
 from autointent.custom_types import NodeType
@@ -24,3 +24,15 @@ class InferenceNodeConfig:
     """One can override presaved embedder config while loading from file system."""
     cross_encoder_config: CrossEncoderConfig | None = None
     """One can override presaved cross encoder config while loading from file system."""
+
+    def asdict(self) -> dict[str, Any]:
+        res = asdict(self)
+        if self.embedder_config is not None:
+            res["embedder_config"] = self.embedder_config.model_dump()
+        else:
+            res.pop("embedder_config")
+        if self.cross_encoder_config is not None:
+            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+        else:
+            res.pop("cross_encoder_config")
+        return res
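
The helper exists so the configs survive `yaml.dump` without custom representers. A rough sketch of the intended behavior; the import paths and constructor arguments are assumptions, since the full field list isn't visible in this hunk:

    from autointent.configs import InferenceNodeConfig  # import path assumed
    from autointent.custom_types import NodeType

    # node_type and load_path are fields referenced elsewhere in this commit;
    # other required fields, if any, are omitted here
    cfg = InferenceNodeConfig(node_type=NodeType.scoring, load_path="runs/scoring_module")

    dumped = cfg.asdict()
    # embedder_config / cross_encoder_config keys are dropped when None and
    # serialized via pydantic's model_dump() when set, so `dumped` contains
    # only plain dicts and strings that yaml.dump can write directly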

autointent/context/_context.py

Lines changed: 2 additions & 19 deletions

@@ -3,7 +3,6 @@
 import json
 import logging
 from pathlib import Path
-from typing import Any
 
 import yaml
 
@@ -32,7 +31,7 @@ class Context:
     optimization_info: OptimizationInfo
     callback_handler = CallbackHandler()
 
-    def __init__(self, seed: int = 42) -> None:
+    def __init__(self, seed: int | None = 42) -> None:
         """Initialize the Context object.
 
         Args:
@@ -71,22 +70,6 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
         """
         self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config)
 
-    def get_inference_config(self) -> dict[str, Any]:
-        """Generate configuration settings for inference.
-
-        Returns:
-            Dictionary containing inference configuration.
-        """
-        nodes_configs = self.optimization_info.get_inference_nodes_config(asdict=True)
-        return {
-            "metadata": {
-                "multilabel": self.is_multilabel(),
-                "n_classes": self.get_n_classes(),
-                "seed": self.seed,
-            },
-            "nodes_configs": nodes_configs,
-        }
-
     def dump(self) -> None:
         """Save logs, configurations, and datasets to disk."""
         self._logger.debug("dumping logs...")
@@ -103,7 +86,7 @@ def dump(self) -> None:
 
         self._logger.info("logs and other assets are saved to %s", logs_dir)
 
-        inference_config = self.get_inference_config()
+        inference_config = self.optimization_info.get_inference_nodes_config(asdict=True)
         inference_config_path = logs_dir / "inference_config.yaml"
         with inference_config_path.open("w") as file:
             yaml.dump(inference_config, file)
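
With `get_inference_config` gone, `inference_config.yaml` now holds a flat list of node configs instead of a mapping with `metadata` and `nodes_configs` keys, which is exactly what the reworked `Pipeline.load` above expects. A sketch of reading it back, with a hypothetical run directory:

    from pathlib import Path

    import yaml

    with (Path("runs/some_run") / "inference_config.yaml").open() as file:
        nodes_configs: list[dict] = yaml.safe_load(file)  # top level is now a list

    for cfg in nodes_configs:
        print(cfg["node_type"], cfg.get("load_path"))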

autointent/context/data_handler/_data_handler.py

Lines changed: 5 additions & 27 deletions

@@ -2,7 +2,7 @@
 
 import logging
 from collections.abc import Generator
-from typing import TypedDict, cast
+from typing import cast
 
 from datasets import concatenate_datasets
 from transformers import set_seed
@@ -16,28 +16,14 @@
 logger = logging.getLogger(__name__)
 
 
-class RegexPatterns(TypedDict):
-    """Regex patterns for each intent class.
-
-    Attributes:
-        id: Intent class id.
-        regex_full_match: Full match regex patterns.
-        regex_partial_match: Partial match regex patterns.
-    """
-
-    id: int
-    regex_full_match: list[str]
-    regex_partial_match: list[str]
-
-
-class DataHandler:  # TODO rename to Validator
+class DataHandler:
     """Data handler class."""
 
     def __init__(
         self,
         dataset: Dataset,
         config: DataConfig | None = None,
-        random_seed: int = 0,
+        random_seed: int | None = 0,
     ) -> None:
         """Initialize the data handler.
 
@@ -46,7 +32,8 @@ def __init__(
             config: Configuration object
             random_seed: Seed for random number generation.
         """
-        set_seed(random_seed)
+        if random_seed is not None:
+            set_seed(random_seed)
         self.random_seed = random_seed
 
         self.dataset = dataset
@@ -59,15 +46,6 @@ def __init__(
         elif self.config.scheme == "cv":
             self._split_cv()
 
-        self.regex_patterns = [
-            RegexPatterns(
-                id=intent.id,
-                regex_full_match=intent.regex_full_match,
-                regex_partial_match=intent.regex_partial_match,
-            )
-            for intent in self.dataset.intents
-        ]
-
         self.intent_descriptions = [intent.description for intent in self.dataset.intents]
         self.tags = self.dataset.get_tags()
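
The guard around `set_seed` is what makes the `int | None` widening safe: `random_seed=None` now leaves global RNG state untouched instead of being passed into `transformers.set_seed`. A minimal sketch, with the import path and a prepared `dataset` assumed:

    from autointent.context.data_handler import DataHandler  # import path assumed

    # `dataset` is an autointent Dataset built elsewhere
    deterministic = DataHandler(dataset=dataset, random_seed=42)  # calls set_seed(42)
    unseeded = DataHandler(dataset=dataset, random_seed=None)     # skips set_seed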
7351

autointent/context/data_handler/_stratification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
self,
3030
test_size: float,
3131
label_feature: str,
32-
random_seed: int,
32+
random_seed: int | None,
3333
shuffle: bool = True,
3434
) -> None:
3535
"""Initialize the StratifiedSplitter.
@@ -283,7 +283,7 @@ def split_dataset(
283283
dataset: Dataset,
284284
split: str,
285285
test_size: float,
286-
random_seed: int,
286+
random_seed: int | None,
287287
allow_oos_in_train: bool | None = None,
288288
) -> tuple[HFDataset, HFDataset]:
289289
"""Split a Dataset object into training and testing subsets.

autointent/metrics/regex.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-"""Regex metrics for intent recognition."""
+"""Metrics for regex modules."""
 
 from typing import Protocol
 

autointent/modules/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
     TunableDecision,
 )
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
-from .regex import Regex
+from .regex import SimpleRegex
 from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
 
 T = TypeVar("T", bound=BaseModule)
@@ -21,7 +21,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
     return {module.name: module for module in modules}
 
 
-REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex])
+REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([SimpleRegex])
 
 EMBEDDING_MODULES: dict[str, type[BaseEmbedding]] = _create_modules_dict(
     [RetrievalAimedEmbedding, LogregAimedEmbedding]
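
After the rename, downstream code imports `SimpleRegex` in place of `Regex`, and the registry still keys modules by their `name` attribute (whose string value isn't shown in this diff):

    from autointent.modules import REGEX_MODULES, SimpleRegex

    assert SimpleRegex in REGEX_MODULES.values()
    print(list(REGEX_MODULES))  # the single registry key is SimpleRegex.name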

autointent/modules/decision/_tunable.py

Lines changed: 2 additions & 2 deletions

@@ -85,7 +85,7 @@ def __init__(
         self,
         target_metric: MetricType = "decision_accuracy",
         n_optuna_trials: PositiveInt = 320,
-        seed: int = 0,
+        seed: int | None = 0,
         tags: list[Tag] | None = None,
     ) -> None:
         """Initialize tunable predictor.
@@ -222,7 +222,7 @@ def fit(
         self,
         probas: npt.NDArray[Any],
         labels: ListOfGenericLabels,
-        seed: int,
+        seed: int | None,
         tags: list[Tag] | None = None,
     ) -> None:
         """Fit the optimizer by finding optimal thresholds.
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-from ._simple import Regex
+from ._simple import SimpleRegex
 
-__all__ = ["Regex"]
+__all__ = ["SimpleRegex"]
