Skip to content

Commit b2c8986

Browse files
committed
fix typing problems (except DataHandler._split_cv)
1 parent 8f30ec9 commit b2c8986

File tree

10 files changed

+38
-23
lines changed

10 files changed

+38
-23
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, Literal
77

88
import numpy as np
99
import yaml
@@ -122,7 +122,7 @@ def _is_inference(self) -> bool:
122122
"""
123123
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
124124

125-
def fit(self, dataset: Dataset) -> Context:
125+
def fit(self, dataset: Dataset, scheme: Literal["ho", "cv"] = "ho") -> Context:
126126
"""
127127
Optimize the pipeline from dataset.
128128
@@ -134,7 +134,7 @@ def fit(self, dataset: Dataset) -> Context:
134134
raise RuntimeError(msg)
135135

136136
context = Context()
137-
context.set_dataset(dataset)
137+
context.set_dataset(dataset, scheme)
138138
context.configure_logging(self.logging_config)
139139
context.configure_vector_index(self.vector_index_config, self.embedder_config)
140140
context.configure_cross_encoder(self.cross_encoder_config)

autointent/configs/_optimization.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Configuration for the optimization process."""
22

33
from pathlib import Path
4+
from typing import Literal
45

56
from pydantic import BaseModel, Field
67

@@ -12,6 +13,8 @@ class DataConfig(BaseModel):
1213

1314
train_path: str | Path
1415
"""Path to the training data. Can be local path or HF repo."""
16+
scheme: Literal["ho", "cv"]
17+
"""Hold-out or cross-validation."""
1518

1619

1720
class TaskConfig(BaseModel):

autointent/context/_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, Literal
77

88
import yaml
99

@@ -83,9 +83,10 @@ def configure_data(self, config: DataConfig) -> None:
8383
self.data_handler = DataHandler(
8484
dataset=load_dataset(config.train_path),
8585
random_seed=self.seed,
86+
scheme=config.scheme
8687
)
8788

88-
def set_dataset(self, dataset: Dataset) -> None:
89+
def set_dataset(self, dataset: Dataset, scheme: Literal["ho", "cv"]) -> None:
8990
"""
9091
Set the datasets for training, validation and testing.
9192
@@ -94,6 +95,7 @@ def set_dataset(self, dataset: Dataset) -> None:
9495
self.data_handler = DataHandler(
9596
dataset=dataset,
9697
random_seed=self.seed,
98+
scheme=scheme,
9799
)
98100

99101
def get_inference_config(self) -> dict[str, Any]:

autointent/context/data_handler/_data_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from transformers import set_seed
1010

1111
from autointent import Dataset
12-
from autointent.custom_types import ListOfGenericLabels, Split
12+
from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split
1313

1414
from ._stratification import split_dataset
1515

@@ -169,7 +169,7 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
169169
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
170170
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
171171

172-
def validation_iterator(self) -> Generator[tuple[list, list, list, list]]:
172+
def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
173173
if self.scheme == "ho":
174174
msg = "Cannot call cross-validation on hold-out DataHandler"
175175
raise RuntimeError(msg)
@@ -180,7 +180,7 @@ def validation_iterator(self) -> Generator[tuple[list, list, list, list]]:
180180
train_folds = [i for i in range(self.n_folds) if i != j]
181181
train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)]
182182
train_labels = [ut for i_fold in train_folds for ut in self.train_labels(i_fold)]
183-
yield train_utterances, train_labels, val_utterances, val_labels
183+
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
184184

185185
msg = "something's wrong"
186186
raise RuntimeError(msg)

autointent/modules/abc/_base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from autointent._dump_tools import Dumper
1313
from autointent.context import Context
1414
from autointent.context.optimization_info import Artifact
15-
from autointent.custom_types import ListOfGenericLabels
15+
from autointent.custom_types import ListOfGenericLabels, ListOfLabels
1616
from autointent.exceptions import WrongClassificationError
1717

1818
logger = logging.getLogger(__name__)
@@ -133,20 +133,22 @@ def score_metrics_ho(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> d
133133
return metrics
134134

135135
def score_metrics_cv(
136-
self, metrics_dict: dict[str, Any], cv_iterator: Iterable[tuple[list, list, list, list]]
136+
self,
137+
metrics_dict: dict[str, Any],
138+
cv_iterator: Iterable[tuple[list[str], ListOfLabels, list[str], ListOfLabels]],
137139
) -> tuple[dict[str, float], list[ListOfGenericLabels] | list[npt.NDArray[Any]]]:
138-
metrics_values = {name: [] for name in metrics_dict}
140+
metrics_values: dict[str, list[float]] = {name: [] for name in metrics_dict}
139141
all_val_preds = []
140142

141143
for train_utterances, train_labels, val_utterances, val_labels in cv_iterator:
142-
self.fit(train_utterances, train_labels)
144+
self.fit(train_utterances, train_labels) # type: ignore[arg-type]
143145
val_preds = self.predict(val_utterances)
144146
for name, fn in metrics_dict.items():
145147
metrics_values[name].append(fn(val_labels, val_preds))
146148
all_val_preds.append(val_preds)
147149

148-
metrics = {name: np.mean(values_list) for name, values_list in metrics_values.items()}
149-
return metrics, all_val_preds
150+
metrics = {name: float(np.mean(values_list)) for name, values_list in metrics_values.items()}
151+
return metrics, all_val_preds # type: ignore[return-value]
150152

151153
def _validate_multilabel(self, data_is_multilabel: bool) -> None:
152154
if data_is_multilabel and not self.supports_multilabel:

autointent/modules/abc/_decision.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
4848
:param split: Target split
4949
:return: Computed metrics value for the test set or error code of metrics
5050
"""
51-
train_scores, train_labels = self.get_train_data(context)
52-
self.fit(train_scores, train_labels, context.data_handler.tags)
51+
train_scores, train_labels, tags = self.get_train_data(context)
52+
self.fit(train_scores, train_labels, tags)
5353

5454
val_labels, val_scores = get_decision_evaluation_data(context, "validation")
5555
decisions = self.predict(val_scores)
@@ -73,22 +73,22 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
7373
raise RuntimeError(msg)
7474

7575
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
76-
metrics_values = {name: [] for name in chosen_metrics}
76+
metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
7777
all_val_decisions = []
7878
for j in range(context.data_handler.n_folds):
7979
val_labels = labels[j]
8080
val_scores = scores[j]
8181
train_folds = [i for i in range(context.data_handler.n_folds) if i != j]
8282
train_labels = [ut for i_fold in train_folds for ut in labels[i_fold]]
8383
train_scores = [ut for i_fold in train_folds for ut in scores[i_fold]]
84-
self.fit(train_scores, train_labels, context.data_handler.tags)
84+
self.fit(train_scores, train_labels, context.data_handler.tags) # type: ignore[arg-type]
8585
val_decisions = self.predict(val_scores)
8686
for name, fn in chosen_metrics.items():
8787
metrics_values[name].append(fn(val_labels, val_decisions))
8888
all_val_decisions.append(val_decisions)
8989

9090
self._artifact = DecisionArtifact(labels=[pred for pred_list in all_val_decisions for pred in pred_list])
91-
return {name: np.mean(values_list) for name, values_list in metrics_values.items()}
91+
return {name: float(np.mean(values_list)) for name, values_list in metrics_values.items()}
9292

9393
def get_assets(self) -> DecisionArtifact:
9494
"""Return useful assets that represent intermediate data into context."""

autointent/modules/abc/_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ class EmbeddingModule(Module, ABC):
1111
"""Base class for embedding modules."""
1212

1313
def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]:
14-
return (context.data_handler.train_utterances(0), context.data_handler.train_labels(0))
14+
return (context.data_handler.train_utterances(0), context.data_handler.train_labels(0)) # type: ignore[return-value]

autointent/modules/abc/_scoring.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class ScoringModule(Module, ABC):
2222

2323
supports_oos = False
2424

25+
@abstractmethod
26+
def fit(
27+
self,
28+
utterances: list[str],
29+
labels: ListOfLabels,
30+
) -> None:
31+
...
32+
2533
def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
2634
train_utterances, train_labels = self.get_train_data(context)
2735
self.fit(train_utterances, train_labels)
@@ -68,7 +76,7 @@ def get_assets(self) -> ScorerArtifact:
6876
return self._artifact
6977

7078
def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]:
71-
return (context.data_handler.train_utterances(0), context.data_handler.train_labels(0))
79+
return (context.data_handler.train_utterances(0), context.data_handler.train_labels(0)) # type: ignore[return-value]
7280

7381
@abstractmethod
7482
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:

autointent/modules/scoring/_description/description.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def clear_cache(self) -> None:
148148
self._embedder.clear_ram()
149149

150150
def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels, list[str]]:
151-
return (
151+
return ( # type: ignore[return-value]
152152
context.data_handler.train_utterances(0),
153153
context.data_handler.train_labels(0),
154154
context.data_handler.intent_descriptions,

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def fit(self, context: Context) -> None:
6767
module_kwargs["embedder_name"] = embedder_name
6868

6969
self._logger.debug("scoring %s module...", module_name)
70-
metrics_score = module.score(context, test=False, metrics=self.metrics)
70+
metrics_score = module.score(context, metrics=self.metrics)
7171
metric_value = metrics_score[self.target_metric]
7272

7373
context.callback_handler.log_metrics(metrics_score)

0 commit comments

Comments (0)