Skip to content

Commit e7ceb52

Browse files
Test/regex module (#156)
* forbid extra fields in pydantic configs
* bug fix for regex module
* add integration test for regex module
* fix code style
* regex bug fix in validation logic
* update test for regex
* fix typing
* Update optimizer_config.schema.json
* update tests
* change cv logic a little bit
* fix `train_classifier` issue
* fix docs building
* fix docs building
* try another

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 2122a33 commit e7ceb52

File tree

17 files changed

+113
-44
lines changed

17 files changed

+113
-44
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
self.sampler = sampler
6161

6262
if isinstance(nodes[0], NodeOptimizer):
63-
self.logging_config = LoggingConfig(dump_dir=None)
63+
self.logging_config = LoggingConfig()
6464
self.embedder_config = EmbedderConfig()
6565
self.cross_encoder_config = CrossEncoderConfig()
6666
self.data_config = DataConfig()

autointent/_ranker.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class CrossEncoderMetadata(TypedDict):
3131
3232
Attributes:
3333
model_name: Name of the model
34-
train_classifier: Whether to train a classifier
34+
train_head: Whether to train a classifier head
3535
device: Device to use for inference
3636
max_length: Maximum sequence length
3737
batch_size: Batch size for inference
3838
"""
3939

4040
model_name: str
41-
train_classifier: bool
41+
train_head: bool
4242
device: str | None
4343
max_length: int | None
4444
batch_size: int
@@ -119,11 +119,11 @@ def __init__(
119119
device=self.cross_encoder_config.device,
120120
max_length=self.cross_encoder_config.max_length, # type: ignore[arg-type]
121121
)
122-
self.train_classifier = False
122+
self.train_head = False
123123
self._clf = classifier_head
124124

125125
if classifier_head is not None or self.cross_encoder_config.train_head:
126-
self.train_classifier = True
126+
self.train_head = True
127127
self._activations_list: list[npt.NDArray[Any]] = []
128128
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)
129129

@@ -147,7 +147,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
147147
Returns:
148148
Array of extracted features or predictions
149149
"""
150-
if not self.train_classifier:
150+
if not self.train_head:
151151
return np.array(
152152
self.cross_encoder.predict(
153153
pairs,
@@ -189,7 +189,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
189189
utterances: List of utterances (texts)
190190
labels: Intent class labels corresponding to the utterances
191191
"""
192-
if not self.train_classifier:
192+
if not self.train_head:
193193
return
194194

195195
pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
@@ -207,7 +207,7 @@ def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
207207
Raises:
208208
ValueError: If classifier is not trained yet
209209
"""
210-
if self.train_classifier and self._clf is None:
210+
if self.train_head and self._clf is None:
211211
msg = "Classifier is not trained yet"
212212
raise ValueError(msg)
213213

@@ -254,7 +254,7 @@ def save(self, path: str) -> None:
254254

255255
metadata = CrossEncoderMetadata(
256256
model_name=self.cross_encoder_config.model_name,
257-
train_classifier=self.train_classifier,
257+
train_head=self.train_head,
258258
device=self.cross_encoder_config.device,
259259
max_length=self.cross_encoder_config.max_length,
260260
batch_size=self.cross_encoder_config.batch_size,

autointent/configs/_optimization.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pathlib import Path
44

5-
from pydantic import BaseModel, Field, PositiveInt
5+
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
66

77
from autointent._callbacks import REPORTERS_NAMES
88
from autointent.custom_types import FloatFromZeroToOne, ValidationScheme
@@ -13,6 +13,7 @@
1313
class DataConfig(BaseModel):
1414
"""Configuration for the data used in the optimization process."""
1515

16+
model_config = ConfigDict(extra="forbid")
1617
scheme: ValidationScheme = Field("ho", description="Validation scheme to use.")
1718
"""Hold-out or cross-validation."""
1819
n_folds: PositiveInt = Field(3, description="Number of folds in cross-validation.")
@@ -33,6 +34,8 @@ class DataConfig(BaseModel):
3334
class LoggingConfig(BaseModel):
3435
"""Configuration for the logging."""
3536

37+
model_config = ConfigDict(extra="forbid")
38+
3639
_dirpath: Path | None = None
3740
_dump_dir: Path | None = None
3841

autointent/configs/_transformers.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
from enum import Enum
22
from typing import Any
33

4-
from pydantic import (
5-
BaseModel,
6-
Field,
7-
PositiveInt,
8-
)
4+
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
95
from typing_extensions import Self, assert_never
106

117

128
class ModelConfig(BaseModel):
9+
model_config = ConfigDict(extra="forbid")
1310
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
1411
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")
1512

autointent/modules/regex/_simple.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""Module for regular expressions based intent detection."""
22

33
import re
4+
from collections.abc import Iterable
45
from typing import Any, TypedDict
56

7+
import numpy as np
8+
import numpy.typing as npt
9+
610
from autointent import Context
711
from autointent.context.data_handler._data_handler import RegexPatterns
812
from autointent.context.optimization_info import Artifact
9-
from autointent.custom_types import LabelType
13+
from autointent.custom_types import LabelType, ListOfGenericLabels, ListOfLabels
1014
from autointent.metrics import REGEX_METRICS
1115
from autointent.modules.base import BaseRegex
1216
from autointent.schemas import Intent
@@ -36,7 +40,10 @@ class Regex(BaseRegex):
3640
name: Name of the module, defaults to "simple"
3741
"""
3842

39-
name = "regex"
43+
name = "simple"
44+
supports_multiclass = True
45+
supports_multilabel = True
46+
supports_oos = False
4047

4148
@classmethod
4249
def from_context(cls, context: Context) -> "Regex":
@@ -158,7 +165,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
158165
return self.score_metrics_ho((val_labels, pred_labels), chosen_metrics)
159166

160167
def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
161-
"""Score the model using cross-validation.
168+
"""Score the model in cross-validation mode.
162169
163170
Args:
164171
context: Context containing validation data
@@ -169,10 +176,42 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
169176
"""
170177
chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics}
171178

172-
metrics_calculated, _ = self.score_metrics_cv(chosen_metrics, context.data_handler.validation_iterator())
179+
metrics_calculated, _ = self.score_metrics_cv(
180+
chosen_metrics, context.data_handler.validation_iterator(), intents=context.data_handler.dataset.intents
181+
)
173182

174183
return metrics_calculated
175184

185+
def score_metrics_cv(
186+
self,
187+
metrics_dict: dict[str, Any],
188+
cv_iterator: Iterable[tuple[list[str], ListOfLabels, list[str], ListOfLabels]],
189+
intents: list[Intent],
190+
) -> tuple[dict[str, float], list[ListOfGenericLabels] | list[npt.NDArray[Any]]]:
191+
"""Score metrics using cross-validation.
192+
193+
Args:
194+
metrics_dict: Dictionary of metrics to compute
195+
cv_iterator: Cross-validation iterator
196+
intents: Intents from the dataset, used to fit the module before scoring
197+
198+
Returns:
199+
Tuple of metrics dictionary and predictions
200+
"""
201+
metrics_values: dict[str, list[float]] = {name: [] for name in metrics_dict}
202+
all_val_preds = []
203+
204+
self.fit(intents)
205+
206+
for _, _, val_utterances, val_labels in cv_iterator:
207+
val_preds = self.predict(val_utterances)
208+
for name, fn in metrics_dict.items():
209+
metrics_values[name].append(fn(val_labels, val_preds))
210+
all_val_preds.append(val_preds)
211+
212+
metrics = {name: float(np.mean(values_list)) for name, values_list in metrics_values.items()}
213+
return metrics, all_val_preds # type: ignore[return-value]
214+
176215
def clear_cache(self) -> None:
177216
"""Clear cached regex patterns."""
178217
del self.regex_patterns

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,6 @@ class DNNCScorer(BaseScorer):
5252
5353
test_utterances = ["Hello!", "What's up?"]
5454
scores = scorer.predict(test_utterances)
55-
print(scores) # Outputs similarity scores for the utterances
56-
57-
58-
.. testoutput::
59-
60-
[[0.00013581 0. ]
61-
[0.00030066 0. ]]
6255
6356
"""
6457

autointent/modules/scoring/_knn/knn.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,6 @@ class KNNScorer(BaseScorer):
4242
scorer.fit(utterances, labels)
4343
test_utterances = ["hi", "what's up?"]
4444
probabilities = scorer.predict(test_utterances)
45-
print(probabilities) # Outputs predicted class probabilities for the utterances
46-
47-
.. testoutput::
48-
49-
[[0.67297815 0.32702185]
50-
[0.44031667 0.55968333]]
5145
5246
"""
5347

autointent/nodes/info/_regex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from autointent.custom_types import NodeType
77
from autointent.metrics import REGEX_METRICS
88
from autointent.metrics.regex import RegexMetricFn
9+
from autointent.modules import REGEX_MODULES
910
from autointent.modules.base import BaseRegex
10-
from autointent.modules.regex import Regex
1111

1212
from ._base import NodeInfo
1313

@@ -17,6 +17,6 @@ class RegexNodeInfo(NodeInfo):
1717

1818
metrics_available: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS
1919

20-
modules_available: ClassVar[Mapping[str, type[BaseRegex]]] = {NodeType.regex: Regex}
20+
modules_available: ClassVar[Mapping[str, type[BaseRegex]]] = REGEX_MODULES
2121

2222
node_type = NodeType.regex

docs/optimizer_config.schema.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"$defs": {
33
"CrossEncoderConfig": {
4+
"additionalProperties": false,
45
"properties": {
56
"batch_size": {
67
"default": 32,
@@ -53,6 +54,7 @@
5354
"type": "object"
5455
},
5556
"DataConfig": {
57+
"additionalProperties": false,
5658
"description": "Configuration for the data used in the optimization process.",
5759
"properties": {
5860
"scheme": {
@@ -100,6 +102,7 @@
100102
"type": "object"
101103
},
102104
"EmbedderConfig": {
105+
"additionalProperties": false,
103106
"properties": {
104107
"batch_size": {
105108
"default": 32,
@@ -230,6 +233,7 @@
230233
"type": "object"
231234
},
232235
"LoggingConfig": {
236+
"additionalProperties": false,
233237
"description": "Configuration for the logging.",
234238
"properties": {
235239
"project_dir": {

tests/assets/configs/regex.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
- node_type: regex
2+
target_metric: regex_partial_accuracy
3+
search_space:
4+
- module_name: simple
5+
- node_type: scoring
6+
target_metric: scoring_roc_auc
7+
search_space:
8+
- module_name: linear
9+
embedder_config:
10+
- model_name: sentence-transformers/all-MiniLM-L6-v2
11+
- node_type: decision
12+
target_metric: decision_accuracy
13+
search_space:
14+
- module_name: argmax

0 commit comments

Comments
 (0)