
Commit 8c092fd

finish fixing typing
1 parent 5d584e8 commit 8c092fd

12 files changed (+28, -56 lines)


autointent/_vector_index.py

Lines changed: 3 additions & 3 deletions
@@ -86,7 +86,7 @@ def add(self, texts: list[str], labels: ListOfLabels) -> None:
         if not hasattr(self, "index"):
             self.index = faiss.IndexFlatIP(embeddings.shape[1])
         self.index.add(embeddings)
-        self.labels.extend(labels)
+        self.labels.extend(labels)  # type: ignore[arg-type]
         self.texts.extend(texts)

     def is_empty(self) -> bool:
@@ -186,9 +186,9 @@ def query(
         func = self._search_by_text if isinstance(queries[0], str) else self._search_by_embedding
         all_results = func(queries, k)  # type: ignore[arg-type]

-        all_labels = [[self.labels[result["id"]] for result in results] for results in all_results]
+        all_labels: list[ListOfLabels] = [[self.labels[result["id"]] for result in results] for results in all_results]
         all_distances = [[float(result["distance"]) for result in results] for results in all_results]
-        all_texts = [[self.texts[result["id"]] for result in results] for results in all_results]
+        all_texts: list[list[str]] = [[self.texts[result["id"]] for result in results] for results in all_results]

         return all_labels, all_distances, all_texts
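
Side note (not part of the commit): the ignores added throughout this change are error-code-scoped, so each one silences a single mypy error code on a single line and nothing else. A minimal, hypothetical sketch of how that scoping behaves:

```python
# Hypothetical sketch, not repository code.
values: list[int] = []

values.append("3")  # type: ignore[arg-type]  # only this arg-type error is silenced
values.append(None)  # still flagged by mypy as an incompatible argument
```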

autointent/metrics/custom_types.py

Lines changed: 3 additions & 3 deletions
@@ -4,10 +4,10 @@

 import numpy.typing as npt

-from autointent.custom_types import LabelType
+from autointent.custom_types import ListOfLabels

-LABELS_VALUE_TYPE = list[LabelType] | npt.NDArray[Any]
+LABELS_VALUE_TYPE = ListOfLabels

-CANDIDATE_TYPE = list[list[LabelType]] | npt.NDArray[Any]
+CANDIDATE_TYPE = list[ListOfLabels] | npt.NDArray[Any]

 SCORES_VALUE_TYPE = list[list[float]] | npt.NDArray[Any]
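
Side note: aliases of the `list[...] | npt.NDArray[Any]` shape let the metric functions accept either plain Python lists or numpy arrays through a single annotation. A rough, self-contained sketch of the pattern; the alias names mirror this file, but the element types are assumptions, not the library's exact definitions:

```python
# Rough sketch; ListOfLabels in autointent.custom_types may be defined differently.
from typing import Any, TypeAlias

import numpy as np
import numpy.typing as npt

ListOfLabels: TypeAlias = list[int] | list[list[int]]  # assumed: class ids or multi-hot rows
CandidateType: TypeAlias = list[ListOfLabels] | npt.NDArray[Any]


def n_queries(candidates: CandidateType) -> int:
    """Works for nested lists and for ndarrays, no conversion required."""
    return len(candidates)


n_queries([[0, 1], [2, 3]])            # plain lists
n_queries(np.array([[0, 1], [2, 3]]))  # numpy array
```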

autointent/metrics/retrieval.py

Lines changed: 5 additions & 6 deletions
@@ -1,6 +1,5 @@
 """Retrieval metrics."""

-from collections.abc import Callable
 from typing import Any, Protocol

 import numpy as np
@@ -36,7 +35,7 @@ def __call__(


 def _macrofy(
-    metric_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any], int | None], float],
+    metric_fn: RetrievalMetricFn,
     query_labels: LABELS_VALUE_TYPE,
     candidates_labels: CANDIDATE_TYPE,
     k: int | None = None,
@@ -72,7 +71,7 @@ def _macrofy(
     for i in range(n_classes):
         binarized_query_labels = query_labels_[..., i]
         binarized_candidates_labels = candidates_labels_[..., i]
-        classwise_values.append(metric_fn(binarized_query_labels, binarized_candidates_labels, k))
+        classwise_values.append(metric_fn(binarized_query_labels, binarized_candidates_labels, k))  # type: ignore[arg-type]

     return np.mean(classwise_values)  # type: ignore[return-value]

@@ -136,12 +135,12 @@ def retrieval_map(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_
     :param k: Number of top items to consider for each query
     :return: Score of the retrieval metric
     """
-    ap_list = [_average_precision(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]
+    ap_list = [_average_precision(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]  # type: ignore[arg-type]
     return sum(ap_list) / len(ap_list)


 def _average_precision_intersecting(
-    query_label: LABELS_VALUE_TYPE, candidate_labels: CANDIDATE_TYPE, k: int | None = None
+    query_label: list[int], candidate_labels: CANDIDATE_TYPE, k: int | None = None
 ) -> float:
     r"""
     Calculate the average precision at position k for the intersecting labels.
@@ -212,7 +211,7 @@ def retrieval_map_intersecting(
     :param k: Number of top items to consider for each query
     :return: Score of the retrieval metric
     """
-    ap_list = [_average_precision_intersecting(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]
+    ap_list = [_average_precision_intersecting(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]  # type: ignore[arg-type]
     return sum(ap_list) / len(ap_list)
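
Note: `RetrievalMetricFn` is presumably the callable protocol defined earlier in this file (note the `Protocol` import and the `def __call__(` hunk context above); typing `_macrofy`'s `metric_fn` with it replaces the narrower `Callable[...]` spelling. A rough illustration of the general pattern only, with assumed parameter types:

```python
# Rough sketch of a callable protocol for metric functions; the actual
# RetrievalMetricFn in autointent/metrics/retrieval.py may differ in detail.
from typing import Any, Protocol

import numpy.typing as npt


class MetricFn(Protocol):
    def __call__(
        self,
        query_labels: list[Any] | npt.NDArray[Any],
        candidates_labels: list[Any] | npt.NDArray[Any],
        k: int | None = None,
    ) -> float: ...


def hit_rate(
    query_labels: list[Any] | npt.NDArray[Any],
    candidates_labels: list[Any] | npt.NDArray[Any],
    k: int | None = None,
) -> float:
    """Any function with a compatible signature satisfies the protocol structurally."""
    return 0.0


metric: MetricFn = hit_rate  # accepted by mypy without explicit subclassing
```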

autointent/modules/abc/_decision.py

Lines changed: 1 addition & 3 deletions
@@ -63,9 +63,7 @@ def get_assets(self) -> DecisionArtifact:
     def clear_cache(self) -> None:
         """Clear cache."""

-    def _validate_inputs(
-        self, scores: npt.NDArray[Any], labels: ListOfGenericLabels
-    ) -> tuple[int, bool]:
+    def _validate_inputs(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> tuple[int, bool, bool]:
         """
         Sanity check if labels and scores are valid to be a training data for decision module.

autointent/modules/decision/_adaptive.py

Lines changed: 4 additions & 3 deletions
@@ -7,7 +7,7 @@
 import numpy.typing as npt

 from autointent import Context
-from autointent.custom_types import ListOfGenericLabels, ListOfLabelsWithOOS
+from autointent.custom_types import ListOfGenericLabels, ListOfLabelsWithOOS, MultiLabel
 from autointent.metrics import decision_f1
 from autointent.modules.abc import DecisionModule
 from autointent.schemas import Tag
@@ -146,10 +146,11 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non
     res = (scores >= thresh[:, None]).astype(int)
     if tags:
         res = apply_tags(res, scores, tags)
-    return [lab if sum(lab) > 0 else None for lab in res.tolist()]
+    y_pred: list[MultiLabel] = res.tolist()  # type: ignore[assignment]
+    return [lab if sum(lab) > 0 else None for lab in y_pred]


-def multilabel_score(y_true: ListOfLabelsWithOOS, y_pred: ListOfLabelsWithOOS) -> float:
+def multilabel_score(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     """
     Calculate the weighted F1 score for multi-label classification.
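
The `res.tolist()` change above is the recurring pattern in this commit: depending on the numpy stubs in use, mypy cannot narrow what `ndarray.tolist()` returns, so the result is bound to an explicitly annotated variable with an ignore scoped to the `assignment` code, and the final comprehension then yields a cleanly typed `None` for rows with no active label. A small, hypothetical sketch of the same idea:

```python
# Hypothetical sketch of the pattern; MultiLabel is assumed to be a multi-hot row.
from typing import TypeAlias

import numpy as np

MultiLabel: TypeAlias = list[int]


def rows_or_none(res: np.ndarray) -> list[MultiLabel | None]:
    # Pin the expected element type of the loosely typed .tolist() result;
    # the ignore is scoped so only an assignment mismatch on this line is silenced.
    y_pred: list[MultiLabel] = res.tolist()  # type: ignore[assignment]
    # Rows with no class above threshold are reported as None (out of scope).
    return [row if sum(row) > 0 else None for row in y_pred]


print(rows_or_none(np.array([[0, 1, 1], [0, 0, 0]])))  # [[0, 1, 1], None]
```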

autointent/modules/decision/_jinoos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def predict(self, scores: npt.NDArray[Any]) -> list[int | None]:
114114
if scores.shape[1] != self._n_classes:
115115
raise InvalidNumClassesError
116116
pred_classes, best_scores = _predict(scores)
117-
y_pred = _detect_oos(pred_classes, best_scores, self._thresh).tolist()
117+
y_pred: list[int] = _detect_oos(pred_classes, best_scores, self._thresh).tolist() # type: ignore[assignment]
118118
return [lab if lab != -1 else None for lab in y_pred]
119119

120120
@staticmethod

autointent/modules/decision/_threshold.py

Lines changed: 7 additions & 5 deletions
@@ -7,7 +7,7 @@
 import numpy.typing as npt

 from autointent import Context
-from autointent.custom_types import ListOfGenericLabels
+from autointent.custom_types import ListOfGenericLabels, MultiLabel
 from autointent.modules.abc import DecisionModule
 from autointent.schemas import Tag

@@ -141,7 +141,7 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
         return multiclass_predict(scores, self.thresh)


-def multiclass_predict(scores: npt.NDArray[Any], thresh: float | npt.NDArray[Any]) -> npt.NDArray[Any]:
+def multiclass_predict(scores: npt.NDArray[Any], thresh: float | npt.NDArray[Any]) -> ListOfGenericLabels:
     """
     Make predictions for multiclass classification task.

@@ -158,14 +158,15 @@ def multiclass_predict(scores: npt.NDArray[Any], thresh: float | npt.NDArray[Any
         thresh_selected = thresh[pred_classes]
         pred_classes[best_scores < thresh_selected] = -1  # out of scope

-    return [lab if lab != -1 else None for lab in pred_classes.tolist()]
+    y_pred: list[int] = pred_classes.tolist()  # type: ignore[assignment]
+    return [lab if lab != -1 else None for lab in y_pred]


 def multilabel_predict(
     scores: npt.NDArray[Any],
     thresh: float | npt.NDArray[Any],
     tags: list[Tag] | None,
-) -> npt.NDArray[Any]:
+) -> ListOfGenericLabels:
     """
     Make predictions for multilabel classification task.

@@ -177,4 +178,5 @@ def multilabel_predict(
     res = (scores >= thresh).astype(int) if isinstance(thresh, float) else (scores >= thresh[None, :]).astype(int)
     if tags:
         res = apply_tags(res, scores, tags)
-    return [lab if sum(lab) > 0 else None for lab in res.tolist()]
+    y_pred: list[MultiLabel] = res.tolist()  # type: ignore[assignment]
+    return [lab if sum(lab) > 0 else None for lab in y_pred]
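
For illustration only: with the return type now `ListOfGenericLabels`, callers get plain Python labels with `None` marking out-of-scope utterances. Assuming `multiclass_predict` picks the argmax class and compares its score against the threshold (the start of the function is not shown in this diff), usage looks roughly like:

```python
# Hypothetical usage; the import path is taken from this file's location, and the
# argmax-vs-threshold behaviour is inferred from the hunks above.
import numpy as np

from autointent.modules.decision._threshold import multiclass_predict

scores = np.array(
    [
        [0.90, 0.10],  # confident prediction for class 0
        [0.40, 0.45],  # best score below the threshold
    ]
)

print(multiclass_predict(scores, 0.5))  # -> [0, None] under the assumptions above
```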

autointent/modules/decision/_tunable.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def objective(self, trial: Trial) -> float:
             y_pred = multilabel_predict(self.probas, thresholds, self.tags)
         else:
             y_pred = multiclass_predict(self.probas, thresholds)
-        return decision_f1(self.labels, y_pred)  # type: ignore[no-any-return]
+        return decision_f1(self.labels, y_pred)

     def fit(
         self,
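
Dropping the `no-any-return` ignore here follows from `decision_f1` now having a concrete return type: mypy's `warn_return_any` check (the check that error code belongs to) only fires when a function annotated to return `float` returns an `Any`-typed expression. A hypothetical before/after sketch:

```python
# Hypothetical sketch; assumes mypy runs with warn_return_any enabled,
# which the original ignore code suggests.
from typing import Any


def metric_loose(y_true: list[int], y_pred: list[int]) -> Any:
    return 1.0


def metric_typed(y_true: list[int], y_pred: list[int]) -> float:
    return 1.0


def objective_before() -> float:
    return metric_loose([0], [0])  # type: ignore[no-any-return]  # Any leaks into a float return


def objective_after() -> float:
    return metric_typed([0], [0])  # concrete float: no ignore needed
```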

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def load(self, path: str) -> None:
         """
         self._vector_index = VectorIndex.load(Path(path))

-    def predict(self, utterances: list[str]) -> tuple[list[list[int | list[int]]], list[list[float]], list[list[str]]]:
+    def predict(self, utterances: list[str]) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]:
         """
         Predict the nearest neighbors for a list of utterances.
autointent/modules/regexp/_regexp.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Module for regular expressions based intent detection."""
22

3-
import json
43
import re
5-
from pathlib import Path
64
from typing import Any, Literal, TypedDict
75

86
from autointent import Context
@@ -140,30 +138,6 @@ def get_assets(self) -> Artifact:
140138
"""Get assets."""
141139
return Artifact()
142140

143-
def load(self, path: str) -> None:
144-
"""
145-
Load data from dump.
146-
147-
:param path: Path to load
148-
"""
149-
dump_dir = Path(path)
150-
151-
with (dump_dir / self.metadata_dict_name).open() as file:
152-
self.regexp_patterns = json.load(file)
153-
154-
self._compile_regex_patterns()
155-
156-
def dump(self, path: str) -> None:
157-
"""
158-
Dump all data needed for inference.
159-
160-
:param path: Path to dump
161-
"""
162-
dump_dir = Path(path)
163-
164-
with (dump_dir / self.metadata_dict_name).open("w") as file:
165-
json.dump(self.regexp_patterns, file, indent=4)
166-
167141
def _compile_regex_patterns(self) -> None:
168142
"""Compile regex patterns."""
169143
self.regexp_patterns_compiled = [

0 commit comments

Comments
 (0)