Skip to content

Commit 5d584e8

Browse files
committed
stage progress on type fixing
1 parent 48d518f commit 5d584e8

File tree

25 files changed

+112
-96
lines changed

25 files changed

+112
-96
lines changed

autointent/_dataset/_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
from datasets import Dataset as HFDataset
1010
from datasets import Sequence, get_dataset_config_names, load_dataset
1111

12-
from autointent.custom_types import LabelType, Split
12+
from autointent.custom_types import LabelWithOOS, Split
1313
from autointent.schemas import Intent, Tag
1414

1515

1616
class Sample(TypedDict):
1717
"""
1818
Typed dictionary representing a dataset sample.
1919
20-
:param str utterance: The text of the utterance.
21-
:param LabelType | None label: The label associated with the utterance, or None if out-of-scope.
20+
:param utterance: The text of the utterance.
21+
:param label: The label associated with the utterance, or None if out-of-scope.
2222
"""
2323

2424
utterance: str
25-
label: LabelType | None
25+
label: LabelWithOOS
2626

2727

2828
class Dataset(dict[str, HFDataset]):

autointent/_pipeline/_pipeline.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
import json
44
import logging
55
from pathlib import Path
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
77

88
import numpy as np
99
import yaml
1010

1111
from autointent import Context, Dataset
1212
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
13-
from autointent.custom_types import LabelType, NodeType
13+
from autointent.custom_types import ListOfGenericLabels, NodeType
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
1616
from autointent.utils import load_default_search_space, load_search_space
1717

1818
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
1919

20+
if TYPE_CHECKING:
21+
from autointent.modules.abc import DecisionModule, ScoringModule
22+
2023

2124
class Pipeline:
2225
"""Pipeline optimizer class."""
@@ -184,7 +187,7 @@ def load(cls, path: str | Path) -> "Pipeline":
184187
inference_dict_config = yaml.safe_load(file)
185188
return cls.from_dict_config(inference_dict_config["nodes_configs"])
186189

187-
def predict(self, utterances: list[str]) -> list[LabelType | None]:
190+
def predict(self, utterances: list[str]) -> ListOfGenericLabels:
188191
"""
189192
Predict the labels for the utterances.
190193
@@ -195,8 +198,11 @@ def predict(self, utterances: list[str]) -> list[LabelType | None]:
195198
msg = "Pipeline in optimization mode cannot perform inference"
196199
raise RuntimeError(msg)
197200

198-
scores = self.nodes[NodeType.scoring].module.predict(utterances) # type: ignore[union-attr]
199-
return self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
201+
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
202+
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
203+
204+
scores = scoring_module.predict(utterances)
205+
return decision_module.predict(scores)
200206

201207
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
202208
"""
@@ -210,7 +216,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
210216
raise RuntimeError(msg)
211217

212218
scores, scores_metadata = self.nodes[NodeType.scoring].module.predict_with_metadata(utterances) # type: ignore[union-attr]
213-
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
219+
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr,arg-type]
214220
regexp_predictions, regexp_predictions_metadata = None, None
215221
if NodeType.regexp in self.nodes:
216222
regexp_predictions, regexp_predictions_metadata = self.nodes[NodeType.regexp].module.predict_with_metadata( # type: ignore[union-attr]

autointent/_pipeline/_schemas.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
from pydantic import BaseModel
44

5-
from autointent.custom_types import LabelType
5+
from autointent.custom_types import LabelWithOOS, ListOfLabels, ListOfLabelsWithOOS
66

77

88
class InferencePipelineUtteranceOutput(BaseModel):
99
"""Output of the inference pipeline for a single utterance."""
1010

1111
utterance: str
12-
prediction: LabelType | None
13-
regexp_prediction: LabelType | None
12+
prediction: LabelWithOOS
13+
regexp_prediction: LabelWithOOS
1414
regexp_prediction_metadata: Any
1515
score: list[float]
1616
score_metadata: Any
@@ -19,6 +19,6 @@ class InferencePipelineUtteranceOutput(BaseModel):
1919
class InferencePipelineOutput(BaseModel):
2020
"""Output of the inference pipeline."""
2121

22-
predictions: list[LabelType | None]
23-
regexp_predictions: list[LabelType] | None = None
22+
predictions: ListOfLabelsWithOOS
23+
regexp_predictions: ListOfLabels | None = None
2424
utterances: list[InferencePipelineUtteranceOutput] | None = None

autointent/_ranker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sklearn.linear_model import LogisticRegressionCV
1919
from torch import nn
2020

21-
from autointent.custom_types import LabelType
21+
from autointent.custom_types import ListOfLabels
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -158,7 +158,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
158158
self._activations_list.clear()
159159
return res # type: ignore[no-any-return]
160160

161-
def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
161+
def _fit(self, pairs: list[tuple[str, str]], labels: ListOfLabels) -> None:
162162
"""
163163
Train the logistic regression model on cross-encoder features.
164164
@@ -181,7 +181,7 @@ def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
181181

182182
self._clf = clf
183183

184-
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
184+
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
185185
"""
186186
Construct training samples and train the logistic regression classifier.
187187

autointent/_vector_index.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy.typing as npt
1616

1717
from autointent import Embedder
18-
from autointent.custom_types import LabelType
18+
from autointent.custom_types import ListOfLabels
1919

2020

2121
class VectorIndexMetadata(TypedDict):
@@ -28,7 +28,7 @@ class VectorIndexMetadata(TypedDict):
2828

2929
class VectorIndexData(TypedDict):
3030
texts: list[str]
31-
labels: list[LabelType]
31+
labels: ListOfLabels
3232

3333

3434
class VectorIndex:
@@ -68,12 +68,12 @@ def __init__(
6868
)
6969
self.embedder_device = embedder_device
7070

71-
self.labels: list[LabelType] = [] # (n_samples,) or (n_samples, n_classes)
71+
self.labels: ListOfLabels = [] # (n_samples,) or (n_samples, n_classes)
7272
self.texts: list[str] = []
7373

7474
self.logger = logging.getLogger(__name__)
7575

76-
def add(self, texts: list[str], labels: list[LabelType]) -> None:
76+
def add(self, texts: list[str], labels: ListOfLabels) -> None:
7777
"""
7878
Add texts and their corresponding labels to the index.
7979
@@ -160,7 +160,7 @@ def get_all_embeddings(self) -> npt.NDArray[Any]:
160160
raise ValueError(msg)
161161
return self.index.reconstruct_n(0, self.index.ntotal) # type: ignore[no-any-return]
162162

163-
def get_all_labels(self) -> list[LabelType]:
163+
def get_all_labels(self) -> ListOfLabels:
164164
"""
165165
Retrieve all labels stored in the index.
166166
@@ -172,7 +172,7 @@ def query(
172172
self,
173173
queries: list[str] | npt.NDArray[np.float32],
174174
k: int,
175-
) -> tuple[list[list[LabelType]], list[list[float]], list[list[str]]]:
175+
) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]:
176176
"""
177177
Query the index to retrieve nearest neighbors.
178178

autointent/context/data_handler/_data_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from transformers import set_seed
99

1010
from autointent import Dataset
11-
from autointent.custom_types import LabelType, Split
11+
from autointent.custom_types import ListOfGenericLabels, Split
1212

1313
from ._stratification import split_dataset
1414

@@ -83,7 +83,7 @@ def train_utterances(self, idx: int | None = None) -> list[str]:
8383
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
8484
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
8585

86-
def train_labels(self, idx: int | None = None) -> list[LabelType]:
86+
def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
8787
"""
8888
Retrieve training labels from the dataset.
8989
@@ -95,7 +95,7 @@ def train_labels(self, idx: int | None = None) -> list[LabelType]:
9595
:return: List of training labels.
9696
"""
9797
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
98-
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
98+
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
9999

100100
def validation_utterances(self, idx: int | None = None) -> list[str]:
101101
"""
@@ -111,7 +111,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
111111
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
112112
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
113113

114-
def validation_labels(self, idx: int | None = None) -> list[LabelType]:
114+
def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
115115
"""
116116
Retrieve validation labels from the dataset.
117117
@@ -123,7 +123,7 @@ def validation_labels(self, idx: int | None = None) -> list[LabelType]:
123123
:return: List of validation labels.
124124
"""
125125
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
126-
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
126+
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
127127

128128
def test_utterances(self, idx: int | None = None) -> list[str]:
129129
"""
@@ -139,7 +139,7 @@ def test_utterances(self, idx: int | None = None) -> list[str]:
139139
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
140140
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
141141

142-
def test_labels(self, idx: int | None = None) -> list[LabelType]:
142+
def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
143143
"""
144144
Retrieve test labels from the dataset.
145145
@@ -151,7 +151,7 @@ def test_labels(self, idx: int | None = None) -> list[LabelType]:
151151
:return: List of test labels.
152152
"""
153153
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
154-
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
154+
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
155155

156156
def dump(self, filepath: str | Path) -> None:
157157
"""

autointent/context/optimization_info/_data_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numpy.typing import NDArray
1111
from pydantic import BaseModel, ConfigDict, Field
1212

13-
from autointent.custom_types import LabelType, NodeType
13+
from autointent.custom_types import ListOfLabelsWithOOS, NodeType
1414

1515

1616
class Artifact(BaseModel):
@@ -53,7 +53,7 @@ class DecisionArtifact(Artifact):
5353
"""
5454

5555
model_config = ConfigDict(arbitrary_types_allowed=True)
56-
labels: list[LabelType | None]
56+
labels: ListOfLabelsWithOOS
5757

5858

5959
def validate_node_name(value: str) -> str:

autointent/custom_types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from enum import Enum
8-
from typing import Literal, TypedDict
8+
from typing import Literal, TypeAlias, TypedDict
99

1010

1111
class LogLevel(Enum):
@@ -29,7 +29,15 @@ class LogLevel(Enum):
2929
"""
3030

3131
# Type alias for label representation
32-
LabelType = int | list[int]
32+
SimpleLabel = int
33+
MultiLabel = list[int]
34+
SimpleLabelWithOOS = SimpleLabel | None
35+
MultiLabelWithOOS = MultiLabel | None
36+
ListOfLabels = list[SimpleLabel] | list[MultiLabel]
37+
ListOfLabelsWithOOS = list[SimpleLabelWithOOS] | list[MultiLabelWithOOS]
38+
LabelType: TypeAlias = SimpleLabel | MultiLabel
39+
LabelWithOOS = LabelType | None
40+
ListOfGenericLabels = ListOfLabels | ListOfLabelsWithOOS
3341
"""
3442
Type alias for label representation
3543

autointent/metrics/decision.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,17 @@
88
import numpy.typing as npt
99
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
1010

11-
from autointent.custom_types import LabelType
11+
from autointent.custom_types import ListOfGenericLabels, ListOfLabels
1212

1313
from ._converter import transform
14-
from .custom_types import LABELS_VALUE_TYPE
1514

1615
logger = logging.getLogger(__name__)
1716

1817

1918
class DecisionMetricFn(Protocol):
2019
"""Protocol for decision metrics."""
2120

22-
def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
21+
def __call__(self, y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
2322
"""
2423
Calculate decision metric.
2524
@@ -32,17 +31,14 @@ def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> floa
3231
...
3332

3433

35-
def handle_oos(
36-
y_true: list[LabelType | None], y_pred: list[LabelType | None]
37-
) -> tuple[list[LabelType], list[LabelType]]:
34+
def handle_oos(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> tuple[ListOfLabels, ListOfLabels]:
3835
"""Convert labels of OOS samples to make them usable in decision metrics."""
3936
in_domain_labels = list(filter(lambda lab: lab is not None, y_true))
40-
multilabel = isinstance(in_domain_labels[0], list)
41-
if multilabel:
37+
if isinstance(in_domain_labels[0], list):
4238
func = _add_oos_multilabel
4339
n_classes = len(in_domain_labels[0])
4440
else:
45-
func = _add_oos_multiclass
41+
func = _add_oos_multiclass # type: ignore[assignment]
4642
n_classes = len(set(in_domain_labels))
4743
func = partial(func, n_classes=n_classes)
4844
return list(map(func, y_true)), list(map(func, y_pred))
@@ -60,7 +56,7 @@ def _add_oos_multilabel(label: list[int] | None, n_classes: int) -> list[int]:
6056
return [*label, 1]
6157

6258

63-
def decision_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
59+
def decision_accuracy(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
6460
r"""
6561
Calculate decision accuracy. Supports both multiclass and multilabel.
6662
@@ -131,7 +127,7 @@ def _decision_roc_auc_multilabel(y_true: npt.NDArray[Any], y_pred: npt.NDArray[A
131127
return float(roc_auc_score(y_true, y_pred, average="macro"))
132128

133129

134-
def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
130+
def decision_roc_auc(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
135131
r"""
136132
Calculate ROC AUC for multiclass and multilabel classification.
137133
@@ -153,7 +149,7 @@ def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> fl
153149
raise ValueError(msg)
154150

155151

156-
def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
152+
def decision_precision(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
157153
r"""
158154
Calculate decision precision. Supports both multiclass and multilabel.
159155
@@ -168,7 +164,7 @@ def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
168164
return float(precision_score(*handle_oos(y_true, y_pred), average="macro"))
169165

170166

171-
def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
167+
def decision_recall(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
172168
r"""
173169
Calculate decision recall. Supports both multiclass and multilabel.
174170
@@ -183,7 +179,7 @@ def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> flo
183179
return float(recall_score(*handle_oos(y_true, y_pred), average="macro"))
184180

185181

186-
def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
182+
def decision_f1(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
187183
r"""
188184
Calculate decision f1 score. Supports both multiclass and multilabel.
189185

autointent/metrics/scoring.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> flo
100100
return float(roc_auc_score(labels_, scores_, average="macro"))
101101

102102

103-
def _calculate_decision_metric(func: DecisionMetricFn, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
103+
def _calculate_decision_metric(
104+
func: DecisionMetricFn, labels: list[int] | list[list[int]], scores: SCORES_VALUE_TYPE
105+
) -> float:
104106
r"""
105107
Calculate decision metric.
106108

0 commit comments

Comments
 (0)