Skip to content

Commit 599794b

Browse files
authored
Add config validation (#104)
* init validation
* init validation
* add validation
* add validation to pipeline
* simplify fields
* make module name literal
* fix fields
* fix task types
* update name and metrics
* update
* add config tests
* update metrics
* fix
* fix naming
* fix docs
1 parent 23efe32 commit 599794b

25 files changed

+561
-61
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from autointent.custom_types import ListOfGenericLabels, NodeType
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
16+
from autointent.nodes.schemes import OptimizationConfig
1617
from autointent.utils import load_default_search_space, load_search_space
1718

1819
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
@@ -72,10 +73,12 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
7273
Create pipeline optimizer from dictionary search space.
7374
7475
:param search_space: Dictionary config
76+
:param seed: random seed
7577
"""
7678
if isinstance(search_space, Path | str):
7779
search_space = load_search_space(search_space)
78-
nodes = [NodeOptimizer(**node) for node in search_space]
80+
validated_search_space = OptimizationConfig(search_space).model_dump() # type: ignore[arg-type]
81+
nodes = [NodeOptimizer(**node) for node in validated_search_space]
7982
return cls(nodes=nodes, seed=seed)
8083

8184
@classmethod
@@ -84,6 +87,9 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
8487
Create pipeline optimizer with default search space for given classification task.
8588
8689
:param multilabel: Whether the task is multi-label or single-label.
90+
:param seed: random seed
91+
92+
:return: Pipeline
8793
"""
8894
return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)
8995

autointent/custom_types.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from enum import Enum
8-
from typing import Literal, TypeAlias, TypedDict
8+
from typing import Literal, TypeAlias
99

1010

1111
class LogLevel(Enum):
@@ -46,10 +46,6 @@ class LogLevel(Enum):
4646
"""
4747

4848

49-
class BaseMetadataDict(TypedDict):
50-
"""Base metadata dictionary for storing additional information."""
51-
52-
5349
class NodeType(str, Enum):
5450
"""Enumeration of node types in the AutoIntent pipeline."""
5551

autointent/modules/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2424
[RetrievalAimedEmbedding, LogregAimedEmbedding]
2525
)
2626

27-
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
27+
RETRIEVAL_MODULES_MULTILABEL: dict[str, type[EmbeddingModule]] = RETRIEVAL_MODULES_MULTICLASS
2828

2929
SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = _create_modules_dict(
3030
[

autointent/modules/decision/_threshold.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ class ThresholdDecision(DecisionModule):
2424
ThresholdDecision uses a predefined threshold (or array of thresholds) to predict
2525
labels for single-label or multi-label classification tasks.
2626
27-
:ivar metadata_dict_name: Filename for saving metadata to disk.
28-
:ivar multilabel: If True, the model supports multi-label classification.
29-
:ivar n_classes: Number of classes in the dataset.
3027
:ivar tags: Tags for predictions (if any).
3128
:ivar name: Name of the predictor, defaults to "adaptive".
3229
@@ -78,17 +75,17 @@ class ThresholdDecision(DecisionModule):
7875

7976
def __init__(
8077
self,
81-
thresh: float | npt.NDArray[Any],
78+
thresh: float | list[float],
8279
) -> None:
8380
"""
8481
Initialize threshold predictor.
8582
8683
:param thresh: Threshold for the scores, shape (n_classes,) or float
8784
"""
88-
self.thresh = thresh
85+
self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)
8986

9087
@classmethod
91-
def from_context(cls, context: Context, thresh: float | npt.NDArray[Any] = 0.5) -> "ThresholdDecision":
88+
def from_context(cls, context: Context, thresh: float | list[float] = 0.5) -> "ThresholdDecision":
9289
"""
9390
Initialize from context.
9491

autointent/modules/embedding/_logreg.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class LogregAimedEmbedding(EmbeddingModule):
2222
The main purpose of this module is to be used at embedding node for optimizing
2323
embedding configuration using its logreg classification quality as a sort of proxy metric.
2424
25-
:ivar classifier: The trained logistic regression model.
26-
:ivar label_encoder: Label encoder for converting labels to numerical format.
25+
:ivar _classifier: The trained logistic regression model.
26+
:ivar _label_encoder: Label encoder for converting labels to numerical format.
2727
:ivar name: Name of the module, defaults to "logreg".
2828
2929
Examples
@@ -42,7 +42,7 @@ class LogregAimedEmbedding(EmbeddingModule):
4242

4343
_classifier: LogisticRegressionCV | MultiOutputClassifier
4444
_label_encoder: LabelEncoder | None
45-
name = "logreg"
45+
name = "logreg_embedding"
4646
supports_multiclass = True
4747
supports_multilabel = True
4848
supports_oos = False
@@ -62,8 +62,8 @@ def __init__(
6262
:param cv: the number of folds used in LogisticRegressionCV
6363
:param embedder_name: Name of the embedder used for creating embeddings.
6464
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
65-
:param batch_size: Batch size for embedding generation.
66-
:param max_length: Maximum sequence length for embeddings. None if not set.
65+
:param embedder_batch_size: Batch size for embedding generation.
66+
:param embedder_max_length: Maximum sequence length for embeddings. None if not set.
6767
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
6868
"""
6969
self.embedder_name = embedder_name

autointent/modules/regexp/_regexp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class RegexPatternsCompiled(TypedDict):
2626
class RegExp(Module):
2727
"""Regular expressions based intent detection module."""
2828

29+
name = "regexp"
30+
2931
@classmethod
3032
def from_context(cls, context: Context) -> "RegExp":
3133
"""Initialize from context."""

autointent/modules/scoring/_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LinearScorer(ScoringModule):
3939
.. testoutput::
4040
4141
[[0.50000032 0.49999968]
42-
[0.50000032 0.49999968]]
42+
[0.44031667 0.55968333]]
4343
4444
"""
4545

autointent/nodes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from ._inference_node import InferenceNode
44
from ._nodes_info import DecisionNodeInfo, EmbeddingNodeInfo, NodeInfo, RegExpNodeInfo, ScoringNodeInfo
55
from ._optimization import NodeOptimizer
6+
from .schemes import OptimizationConfig
67

78
__all__ = [
89
"DecisionNodeInfo",
910
"EmbeddingNodeInfo",
1011
"InferenceNode",
1112
"NodeInfo",
1213
"NodeOptimizer",
14+
"OptimizationConfig",
1315
"RegExpNodeInfo",
1416
"ScoringNodeInfo",
1517
]

autointent/nodes/_nodes_info/_embedding.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from autointent.metrics import (
88
RETRIEVAL_METRICS_MULTICLASS,
99
RETRIEVAL_METRICS_MULTILABEL,
10+
SCORING_METRICS_MULTICLASS,
11+
SCORING_METRICS_MULTILABEL,
1012
RetrievalMetricFn,
13+
ScoringMetricFn,
1114
)
1215
from autointent.modules import RETRIEVAL_MODULES_MULTICLASS, RETRIEVAL_MODULES_MULTILABEL
1316
from autointent.modules.abc import Module
@@ -18,12 +21,15 @@
1821
class EmbeddingNodeInfo(NodeInfo):
1922
"""Embedding node info."""
2023

21-
metrics_available: ClassVar[Mapping[str, RetrievalMetricFn]] = (
22-
RETRIEVAL_METRICS_MULTICLASS | RETRIEVAL_METRICS_MULTILABEL
24+
metrics_available: ClassVar[Mapping[str, RetrievalMetricFn | ScoringMetricFn]] = (
25+
RETRIEVAL_METRICS_MULTICLASS
26+
| RETRIEVAL_METRICS_MULTILABEL
27+
| SCORING_METRICS_MULTILABEL
28+
| SCORING_METRICS_MULTICLASS
2329
)
2430

2531
modules_available: ClassVar[Mapping[str, type[Module]]] = (
26-
RETRIEVAL_MODULES_MULTICLASS | RETRIEVAL_MODULES_MULTILABEL # type: ignore[has-type]
32+
RETRIEVAL_MODULES_MULTICLASS | RETRIEVAL_MODULES_MULTILABEL
2733
)
2834

2935
node_type = NodeType.embedding

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def __init__(
3535
"""
3636
self.node_type = node_type
3737
self.node_info = NODES_INFO[node_type]
38-
self.decision_metric_name = target_metric
38+
self.target_metric = target_metric
3939

4040
self.metrics = metrics if metrics is not None else []
41-
if self.decision_metric_name not in self.metrics:
42-
self.metrics.append(self.decision_metric_name)
41+
if self.target_metric not in self.metrics:
42+
self.metrics.append(self.target_metric)
4343

4444
self.modules_search_spaces = search_space # TODO search space validation
4545
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem
@@ -73,7 +73,7 @@ def fit(self, context: Context) -> None:
7373

7474
self._logger.debug("scoring %s module...", module_name)
7575
metrics_score = module.score(context, "validation", self.metrics)
76-
metric_value = metrics_score[self.decision_metric_name]
76+
metric_value = metrics_score[self.target_metric]
7777

7878
context.callback_handler.log_metrics(metrics_score)
7979
context.callback_handler.end_module()
@@ -91,7 +91,7 @@ def fit(self, context: Context) -> None:
9191
module_name,
9292
module_kwargs,
9393
metric_value,
94-
self.decision_metric_name,
94+
self.target_metric,
9595
module.get_assets(), # retriever name / scores / predictions
9696
module_dump_dir,
9797
module=module if not context.is_ram_to_clear() else None,

0 commit comments

Comments
 (0)