Skip to content

Commit dcee174

Browse files
committed
fix: typing
1 parent 31147a9 commit dcee174

File tree

6 files changed

+9
-22
lines changed

6 files changed

+9
-22
lines changed

autointent/modules/abc/_base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def fit(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
3434
"""
3535

3636
@abstractmethod
37-
def score(
38-
self, context: Context, split: Literal["validation", "test"], metrics: list[str]
39-
) -> dict[str, float | str]:
37+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
4038
"""
4139
Calculate metric on test set and return metric value.
4240
@@ -108,7 +106,7 @@ def get_embedder_name(self) -> str | None:
108106
return None
109107

110108
@staticmethod
111-
def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float | str]:
109+
def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]:
112110
"""
113111
Score metrics on the test set.
114112

autointent/modules/abc/_decision.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,7 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
4040
:param scores: Scores to predict
4141
"""
4242

43-
def score(
44-
self,
45-
context: Context,
46-
split: Literal["validation", "test"],
47-
) -> dict[str, float | str]:
43+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
4844
"""
4945
Calculate metric on test set and return metric value.
5046
@@ -54,7 +50,8 @@ def score(
5450
"""
5551
labels, scores = get_decision_evaluation_data(context, split)
5652
self._decisions = self.predict(scores)
57-
return self.score_metrics((labels, self._decisions), PREDICTION_METRICS_MULTICLASS)
53+
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
54+
return self.score_metrics((labels, self._decisions), chosen_metrics)
5855

5956
def get_assets(self) -> DecisionArtifact:
6057
"""Return useful assets that represent intermediate data into context."""

autointent/modules/abc/_scoring.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ class ScoringModule(Module, ABC):
2121

2222
supports_oos = False
2323

24-
def score(
25-
self, context: Context, split: Literal["validation", "test"], metrics: list[str]
26-
) -> dict[str, float | str]:
24+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
2725
"""
2826
Evaluate the scorer on a test set and compute the specified metric.
2927

autointent/modules/embedding/_logreg.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
129129

130130
self._classifier.fit(embeddings, labels)
131131

132-
def score(
133-
self, context: Context, split: Literal["validation", "test"], metrics: list[str]
134-
) -> dict[str, float | str]:
132+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
135133
"""
136134
Evaluate the embedding model using a specified metric function.
137135

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
109109
)
110110
self._vector_index.add(utterances, labels)
111111

112-
def score(
113-
self, context: Context, split: Literal["validation", "test"], metrics: list[str]
114-
) -> dict[str, float | str]:
112+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
115113
"""
116114
Evaluate the embedding model using a specified metric function.
117115

autointent/modules/regexp/_regexp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str
108108
matches["partial_matches"].extend(intent_matches["partial_matches"])
109109
return list(prediction), matches
110110

111-
def score(
112-
self, context: Context, split: Literal["validation", "test"], metrics: list[str]
113-
) -> dict[str, float | str]:
111+
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
114112
"""
115113
Calculate metric on test set and return metric value.
116114

0 commit comments

Comments
 (0)