Skip to content

Commit 29950f6

Browse files
committed
make naming consistent: retriever -> embedding; prediction -> decision
1 parent 9e8a744 commit 29950f6

File tree

11 files changed

+32
-32
lines changed

11 files changed

+32
-32
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from autointent import Context, Dataset
1212
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
1313
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
14-
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
14+
from autointent.metrics import DECISION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
1616
from autointent.nodes.schemes import OptimizationConfig
1717
from autointent.utils import load_default_search_space, load_search_space
@@ -155,7 +155,7 @@ def fit(
155155
self._refit(context)
156156

157157
predictions = self.predict(context.data_handler.test_utterances())
158-
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
158+
for metric_name, metric in DECISION_METRICS_MULTILABEL.items():
159159
context.optimization_info.pipeline_metrics[metric_name] = metric(
160160
context.data_handler.test_labels(),
161161
predictions,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._data_models import Artifact, DecisionArtifact, RetrieverArtifact, ScorerArtifact
1+
from ._data_models import Artifact, DecisionArtifact, EmbeddingArtifact, ScorerArtifact
22
from ._optimization_info import OptimizationInfo
33

4-
__all__ = ["Artifact", "DecisionArtifact", "OptimizationInfo", "RetrieverArtifact", "ScorerArtifact"]
4+
__all__ = ["Artifact", "DecisionArtifact", "OptimizationInfo", "EmbeddingArtifact", "ScorerArtifact"]

autointent/context/optimization_info/_data_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class RegexArtifact(Artifact):
2222
"""Artifact containing results from the regex node."""
2323

2424

25-
class RetrieverArtifact(Artifact):
25+
class EmbeddingArtifact(Artifact):
2626
"""
2727
Artifact containing details from the embedding node.
2828
@@ -84,7 +84,7 @@ class Artifacts(BaseModel):
8484
model_config = ConfigDict(arbitrary_types_allowed=True)
8585

8686
regex: list[RegexArtifact] = []
87-
embedding: list[RetrieverArtifact] = []
87+
embedding: list[EmbeddingArtifact] = []
8888
scoring: list[ScorerArtifact] = []
8989
decision: list[DecisionArtifact] = []
9090

autointent/context/optimization_info/_optimization_info.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from autointent.custom_types import NodeType
1616
from autointent.schemas import EmbedderConfig
1717

18-
from ._data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
18+
from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials, TrialsIds
1919

2020
if TYPE_CHECKING:
2121
from autointent.modules.abc import BaseModule
@@ -126,7 +126,7 @@ def _get_best_trial_idx(self, node_type: str) -> int | None:
126126
self._trials_best_ids.set_best_trial_idx(node_type, best_idx)
127127
return best_idx
128128

129-
def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifact | Artifact:
129+
def _get_best_artifact(self, node_type: str) -> EmbeddingArtifact | ScorerArtifact | Artifact:
130130
"""
131131
Retrieve the best artifact for a specific node type.
132132
@@ -146,7 +146,7 @@ def get_best_embedder(self) -> EmbedderConfig:
146146
147147
:return: Name of the best embedder.
148148
"""
149-
best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type=NodeType.embedding) # type: ignore[assignment]
149+
best_retriever_artifact: EmbeddingArtifact = self._get_best_artifact(node_type=NodeType.embedding) # type: ignore[assignment]
150150
return best_retriever_artifact.config
151151

152152
def get_best_train_scores(self) -> NDArray[np.float64] | None:

autointent/metrics/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@
8080
scoring_neg_ranking_loss,
8181
)
8282

83-
PREDICTION_METRICS_MULTICLASS: dict[str, DecisionMetricFn] = _funcs_to_dict(
83+
DECISION_METRICS_MULTICLASS: dict[str, DecisionMetricFn] = _funcs_to_dict(
8484
decision_accuracy,
8585
decision_f1,
8686
decision_precision,
8787
decision_recall,
8888
decision_roc_auc,
8989
)
9090

91-
PREDICTION_METRICS_MULTILABEL = PREDICTION_METRICS_MULTICLASS
91+
DECISION_METRICS_MULTILABEL = DECISION_METRICS_MULTICLASS
9292

9393
REGEX_METRICS = _funcs_to_dict(regex_partial_accuracy, regex_partial_precision)
9494

autointent/modules/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2323

2424
REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex])
2525

26-
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[BaseEmbedding]] = _create_modules_dict(
26+
EMBEDDING_MODULES_MULTICLASS: dict[str, type[BaseEmbedding]] = _create_modules_dict(
2727
[RetrievalAimedEmbedding, LogregAimedEmbedding]
2828
)
2929

30-
RETRIEVAL_MODULES_MULTILABEL: dict[str, type[BaseEmbedding]] = RETRIEVAL_MODULES_MULTICLASS
30+
EMBEDDING_MODULES_MULTILABEL: dict[str, type[BaseEmbedding]] = EMBEDDING_MODULES_MULTICLASS
3131

3232
SCORING_MODULES_MULTICLASS: dict[str, type[BaseScorer]] = _create_modules_dict(
3333
[
@@ -49,11 +49,11 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4949
],
5050
)
5151

52-
PREDICTION_MODULES_MULTICLASS: dict[str, type[BaseDecision]] = _create_modules_dict(
52+
DECISION_MODULES_MULTICLASS: dict[str, type[BaseDecision]] = _create_modules_dict(
5353
[ArgmaxDecision, JinoosDecision, ThresholdDecision, TunableDecision],
5454
)
5555

56-
PREDICTION_MODULES_MULTILABEL: dict[str, type[BaseDecision]] = _create_modules_dict(
56+
DECISION_MODULES_MULTILABEL: dict[str, type[BaseDecision]] = _create_modules_dict(
5757
[AdaptiveDecision, ThresholdDecision, TunableDecision],
5858
)
5959

autointent/modules/abc/_decision.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from autointent import Context
1010
from autointent.context.optimization_info import DecisionArtifact
1111
from autointent.custom_types import ListOfGenericLabels
12-
from autointent.metrics import PREDICTION_METRICS_MULTICLASS
12+
from autointent.metrics import DECISION_METRICS_MULTICLASS
1313
from autointent.modules.abc import BaseModule
1414
from autointent.schemas import Tag
1515

@@ -53,7 +53,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
5353

5454
val_labels, val_scores = get_decision_evaluation_data(context, "validation")
5555
decisions = self.predict(val_scores)
56-
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
56+
chosen_metrics = {name: fn for name, fn in DECISION_METRICS_MULTICLASS.items() if name in metrics}
5757
self._artifact = DecisionArtifact(labels=decisions)
5858
return self.score_metrics_ho((val_labels, decisions), chosen_metrics)
5959

@@ -72,7 +72,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
7272
msg = "No folded scores are found."
7373
raise RuntimeError(msg)
7474

75-
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
75+
chosen_metrics = {name: fn for name, fn in DECISION_METRICS_MULTICLASS.items() if name in metrics}
7676
metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
7777
all_val_decisions = []
7878
for j in range(context.data_handler.n_folds):

autointent/modules/embedding/_logreg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.preprocessing import LabelEncoder
1111

1212
from autointent import Context, Embedder
13-
from autointent.context.optimization_info import RetrieverArtifact
13+
from autointent.context.optimization_info import EmbeddingArtifact
1414
from autointent.custom_types import ListOfLabels
1515
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
1616
from autointent.modules.abc import BaseEmbedding
@@ -146,13 +146,13 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
146146
metrics_calculated, _ = self.score_metrics_cv(chosen_metrics, context.data_handler.validation_iterator())
147147
return metrics_calculated
148148

149-
def get_assets(self) -> RetrieverArtifact:
149+
def get_assets(self) -> EmbeddingArtifact:
150150
"""
151151
Get the classifier artifacts for this module.
152152
153-
:return: A RetrieverArtifact object containing embedder information.
153+
:return: A EmbeddingArtifact object containing embedder information.
154154
"""
155-
return RetrieverArtifact(config=self.embedder_config)
155+
return EmbeddingArtifact(config=self.embedder_config)
156156

157157
def predict(self, utterances: list[str]) -> NDArray[np.float64]:
158158
embeddings = self._embedder.embed(utterances, TaskTypeEnum.classification)

autointent/modules/embedding/_retrieval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import PositiveInt
66

77
from autointent import Context, VectorIndex
8-
from autointent.context.optimization_info import RetrieverArtifact
8+
from autointent.context.optimization_info import EmbeddingArtifact
99
from autointent.custom_types import ListOfLabels
1010
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
1111
from autointent.modules.abc import BaseEmbedding
@@ -124,13 +124,13 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
124124
metrics_calculated, _ = self.score_metrics_cv(chosen_metrics, context.data_handler.validation_iterator())
125125
return metrics_calculated
126126

127-
def get_assets(self) -> RetrieverArtifact:
127+
def get_assets(self) -> EmbeddingArtifact:
128128
"""
129129
Get the retriever artifacts for this module.
130130
131-
:return: A RetrieverArtifact object containing embedder information.
131+
:return: A EmbeddingArtifact object containing embedder information.
132132
"""
133-
return RetrieverArtifact(config=self.embedder_config)
133+
return EmbeddingArtifact(config=self.embedder_config)
134134

135135
def clear_cache(self) -> None:
136136
"""Clear cached data in memory used by the vector index."""

autointent/nodes/info/_decision.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from typing import ClassVar
55

66
from autointent.custom_types import NodeType
7-
from autointent.metrics import PREDICTION_METRICS_MULTICLASS, PREDICTION_METRICS_MULTILABEL, DecisionMetricFn
8-
from autointent.modules import PREDICTION_MODULES_MULTICLASS, PREDICTION_MODULES_MULTILABEL
7+
from autointent.metrics import DECISION_METRICS_MULTICLASS, DECISION_METRICS_MULTILABEL, DecisionMetricFn
8+
from autointent.modules import DECISION_MODULES_MULTICLASS, DECISION_MODULES_MULTILABEL
99
from autointent.modules.abc import BaseDecision
1010

1111
from ._base import NodeInfo
@@ -15,11 +15,11 @@ class DecisionNodeInfo(NodeInfo):
1515
"""Prediction node info."""
1616

1717
metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = (
18-
PREDICTION_METRICS_MULTICLASS | PREDICTION_METRICS_MULTILABEL
18+
DECISION_METRICS_MULTICLASS | DECISION_METRICS_MULTILABEL
1919
)
2020

2121
modules_available: ClassVar[dict[str, type[BaseDecision]]] = (
22-
PREDICTION_MODULES_MULTICLASS | PREDICTION_MODULES_MULTILABEL
22+
DECISION_MODULES_MULTICLASS | DECISION_MODULES_MULTILABEL
2323
)
2424

2525
node_type = NodeType.decision

0 commit comments

Comments
 (0)