Skip to content

Commit 075b8de

Browse files
authored
fix names (#37)
* fix names * lint
1 parent fcf61a3 commit 075b8de

File tree

25 files changed

+82
-50
lines changed

25 files changed

+82
-50
lines changed

autointent/context/optimization_info/data_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from numpy.typing import NDArray
55
from pydantic import BaseModel, ConfigDict, Field
66

7+
from autointent.custom_types import NodeType
8+
79

810
class Artifact(BaseModel): ...
911

@@ -40,7 +42,7 @@ class PredictorArtifact(Artifact):
4042

4143

4244
def validate_node_name(value: str) -> str:
43-
if value in ["regexp", "retrieval", "scoring", "prediction"]:
45+
if value in [NodeType.retrieval, NodeType.scoring, NodeType.prediction, NodeType.regexp]:
4446
return value
4547
msg = f"Unknown node_type: {value}. Expected one of ['regexp', 'retrieval', 'scoring', 'prediction']"
4648
raise ValueError(msg)

autointent/context/optimization_info/optimization_info.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numpy.typing import NDArray
55

66
from autointent.configs.node import InferenceNodeConfig
7+
from autointent.custom_types import NODE_TYPES, NodeType
78
from autointent.logger import get_logger
89

910
from .data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
@@ -71,32 +72,28 @@ def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifa
7172
return self.artifacts.get_best_artifact(node_type, i_best)
7273

7374
def get_best_embedder(self) -> str:
74-
best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type="retrieval") # type: ignore[assignment]
75+
best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type=NodeType.retrieval) # type: ignore[assignment]
7576
return best_retriever_artifact.embedder_name
7677

7778
def get_best_test_scores(self) -> NDArray[np.float64] | None:
78-
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type="scoring") # type: ignore[assignment]
79+
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
7980
return best_scorer_artifact.test_scores
8081

8182
def get_best_oos_scores(self) -> NDArray[np.float64] | None:
82-
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type="scoring") # type: ignore[assignment]
83+
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
8384
return best_scorer_artifact.oos_scores
8485

8586
def dump_evaluation_results(self) -> dict[str, dict[str, list[float]]]:
86-
node_wise_metrics = {
87-
node_type: self._get_metrics_values(node_type)
88-
for node_type in ["regexp", "retrieval", "scoring", "prediction"]
89-
}
87+
node_wise_metrics = {node_type.value: self._get_metrics_values(node_type) for node_type in NODE_TYPES}
9088
return {
9189
"metrics": node_wise_metrics,
9290
"configs": self.trials.model_dump(),
9391
}
9492

9593
def get_inference_nodes_config(self) -> list[InferenceNodeConfig]:
96-
node_types = ["regexp", "retrieval", "scoring", "prediction"]
97-
trial_ids = [self._get_best_trial_idx(node_type) for node_type in node_types]
94+
trial_ids = [self._get_best_trial_idx(node_type) for node_type in NODE_TYPES]
9895
res = []
99-
for idx, node_type in zip(trial_ids, node_types, strict=True):
96+
for idx, node_type in zip(trial_ids, NODE_TYPES, strict=True):
10097
if idx is None:
10198
continue
10299
trial = self.trials.get_trial(node_type, idx)

autointent/custom_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,13 @@ class LogLevel(Enum):
1919

2020
class BaseMetadataDict(TypedDict):
2121
pass
22+
23+
24+
class NodeType(str, Enum):
25+
retrieval = "retrieval"
26+
prediction = "prediction"
27+
scoring = "scoring"
28+
regexp = "regexp"
29+
30+
31+
NODE_TYPES = [NodeType.retrieval, NodeType.prediction, NodeType.scoring, NodeType.regexp]

autointent/generation/prompt_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
class PromptDescription(BaseModel):
77
text: str = PROMPT_DESCRIPTION
88

9-
@field_validator("text")
109
@classmethod
10+
@field_validator("text")
1111
def check_valid_prompt(cls, value: str) -> str:
1212
if value.find("{intent_name}") == -1 or value.find("{user_utterances}") == -1:
1313
text_error = (

autointent/modules/__init__.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TypeVar
2+
13
from .base import Module
24
from .prediction import (
35
ArgmaxPredictor,
@@ -10,36 +12,31 @@
1012
from .retrieval import RetrievalModule, VectorDBModule
1113
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule
1214

13-
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[Module]] = {
14-
"vector_db": VectorDBModule,
15-
}
15+
T = TypeVar("T", bound=Module)
16+
17+
18+
def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
19+
return {module.name: module for module in modules}
20+
21+
22+
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[Module]] = create_modules_dict([VectorDBModule])
1623

1724
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
1825

19-
SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = {
20-
"dnnc": DNNCScorer,
21-
"knn": KNNScorer,
22-
"linear": LinearScorer,
23-
"description": DescriptionScorer,
24-
}
26+
SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict(
27+
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer]
28+
)
29+
30+
SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict(
31+
[MLKnnScorer, LinearScorer, DescriptionScorer]
32+
)
2533

26-
SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = {
27-
"knn": KNNScorer,
28-
"linear": LinearScorer,
29-
"mlknn": MLKnnScorer,
30-
}
34+
PREDICTION_MODULES_MULTICLASS: dict[str, type[Module]] = create_modules_dict(
35+
[ArgmaxPredictor, JinoosPredictor, ThresholdPredictor, TunablePredictor]
36+
)
3137

32-
PREDICTION_MODULES_MULTICLASS: dict[str, type[Module]] = {
33-
"argmax": ArgmaxPredictor,
34-
"jinoos": JinoosPredictor,
35-
"threshold": ThresholdPredictor,
36-
"tunable": TunablePredictor,
37-
}
38+
PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = create_modules_dict([ThresholdPredictor, TunablePredictor])
3839

39-
PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = {
40-
"threshold": ThresholdPredictor,
41-
"tunable": TunablePredictor,
42-
}
4340
__all__ = [
4441
"Module",
4542
"ArgmaxPredictor",

autointent/modules/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
class Module(ABC):
14+
name: str
15+
1416
metadata_dict_name: str = "metadata.json"
1517
metadata: BaseMetadataDict
1618

autointent/modules/prediction/argmax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414
class ArgmaxPredictor(PredictionModule):
1515
metadata = {} # noqa: RUF012
16+
name = "argmax"
1617

1718
def __init__(self) -> None:
1819
pass
1920

21+
2022
@classmethod
2123
def from_context(cls, context: Context) -> Self:
2224
return cls()

autointent/modules/prediction/jinoos.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class JinoosPredictorDumpMetadata(BaseMetadataDict):
2222

2323
class JinoosPredictor(PredictionModule):
2424
thresh: float
25+
name = "jinoos"
2526

2627
def __init__(
2728
self,

autointent/modules/prediction/threshold.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ThresholdPredictor(PredictionModule):
2727
metadata: ThresholdPredictorDumpMetadata
2828
multilabel: bool
2929
tags: list[Tag] | None
30+
name = "threshold"
3031

3132
def __init__(
3233
self,

autointent/modules/prediction/tunable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class TunablePredictorDumpMetadata(BaseMetadataDict):
2424

2525

2626
class TunablePredictor(PredictionModule):
27+
name = "tunable"
28+
2729
def __init__(
2830
self,
2931
n_trials: int = 320,

0 commit comments

Comments
 (0)