Skip to content

Commit 2e4b5e2

Browse files
committed
resolve part conversations
1 parent 921cb66 commit 2e4b5e2

File tree

19 files changed

+69
-81
lines changed

19 files changed

+69
-81
lines changed

autointent/_ranker.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch import nn
2121

2222
from autointent.custom_types import ListOfLabels
23-
from autointent.schemas._schemas import CrossEncoderConfig
23+
from autointent.schemas import CrossEncoderConfig
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -106,7 +106,7 @@ class Ranker:
106106

107107
def __init__(
108108
self,
109-
cross_encoder_config: CrossEncoderConfig,
109+
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any],
110110
classifier_head: LogisticRegressionCV | None = None,
111111
) -> None:
112112
"""
@@ -116,10 +116,7 @@ def __init__(
116116
:param max_length (int, optional): Max length for input sequences for the cross encoder.
117117
:param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
118118
"""
119-
if isinstance(cross_encoder_config, dict):
120-
cross_encoder_config = CrossEncoderConfig(**cross_encoder_config)
121-
if isinstance(cross_encoder_config, str):
122-
cross_encoder_config = CrossEncoderConfig(model_name=cross_encoder_config)
119+
cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
123120
self.cross_encoder = st.CrossEncoder(
124121
cross_encoder_config.model_name,
125122
trust_remote_code=True,

autointent/_vector_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from autointent import Embedder
1818
from autointent.custom_types import ListOfLabels
19-
from autointent.schemas._schemas import EmbedderConfig
19+
from autointent.schemas import EmbedderConfig
2020

2121

2222
class VectorIndexMetadata(TypedDict):

autointent/context/optimization_info/_data_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic import BaseModel, ConfigDict, Field
1212

1313
from autointent.custom_types import ListOfLabelsWithOOS, NodeType
14-
from autointent.schemas._schemas import EmbedderConfig
14+
from autointent.schemas import EmbedderConfig
1515

1616

1717
class Artifact(BaseModel):

autointent/context/optimization_info/_optimization_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from autointent.configs import InferenceNodeConfig
1515
from autointent.custom_types import NodeType
16-
from autointent.schemas._schemas import EmbedderConfig
16+
from autointent.schemas import EmbedderConfig
1717

1818
from ._data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
1919

autointent/modules/abc/_base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from autointent.context.optimization_info import Artifact
1515
from autointent.custom_types import ListOfGenericLabels, ListOfLabels
1616
from autointent.exceptions import WrongClassificationError
17+
from autointent.schemas import EmbedderConfig
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -114,8 +115,12 @@ def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "Module":
114115
:param kwargs: Additional kwargs.
115116
"""
116117

117-
def get_embedder_name(self) -> str | None:
118-
"""Experimental method."""
118+
def get_embedder_config(self) -> EmbedderConfig | None:
119+
"""
120+
Get the config of the embedder.
121+
122+
:return: Embedder config.
123+
"""
119124
return None
120125

121126
@staticmethod

autointent/modules/embedding/_logreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from autointent.custom_types import ListOfLabels
1414
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
1515
from autointent.modules.abc import EmbeddingModule
16-
from autointent.schemas._schemas import EmbedderConfig
16+
from autointent.schemas import EmbedderConfig
1717

1818

1919
class LogregAimedEmbedding(EmbeddingModule):

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from autointent.custom_types import ListOfLabels
88
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
99
from autointent.modules.abc import EmbeddingModule
10-
from autointent.schemas._schemas import EmbedderConfig
10+
from autointent.schemas import EmbedderConfig
1111

1212

1313
class RetrievalAimedEmbedding(EmbeddingModule):

autointent/modules/scoring/_description/description.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from autointent.custom_types import ListOfLabels
1313
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
1414
from autointent.modules.abc import ScoringModule
15-
from autointent.schemas._schemas import EmbedderConfig
15+
from autointent.schemas import EmbedderConfig
1616

1717

1818
class DescriptionScorer(ScoringModule):
@@ -47,12 +47,7 @@ def __init__(
4747
:param temperature: Temperature parameter for scaling logits, defaults to 1.0.
4848
"""
4949
self.temperature = temperature
50-
if isinstance(embedder_config, dict):
51-
embedder_config = EmbedderConfig(**embedder_config)
52-
if isinstance(embedder_config, str):
53-
embedder_config = EmbedderConfig(model_name=embedder_config)
54-
55-
self.embedder_config = embedder_config
50+
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
5651

5752
@classmethod
5853
def from_context(
@@ -77,7 +72,7 @@ def from_context(
7772
embedder_config=embedder_config,
7873
)
7974

80-
def get_embedder_name(self) -> EmbedderConfig:
75+
def get_embedder_config(self) -> EmbedderConfig:
8176
"""
8277
Get the name of the embedder.
8378
@@ -129,10 +124,10 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
129124
similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors)
130125

131126
if self._multilabel:
132-
probabilites = scipy.special.expit(similarities / self.temperature)
127+
probabilities = scipy.special.expit(similarities / self.temperature)
133128
else:
134-
probabilites = scipy.special.softmax(similarities / self.temperature, axis=1)
135-
return probabilites # type: ignore[no-any-return]
129+
probabilities = scipy.special.softmax(similarities / self.temperature, axis=1)
130+
return probabilities # type: ignore[no-any-return]
136131

137132
def clear_cache(self) -> None:
138133
"""Clear cached data in memory used by the embedder."""
@@ -150,7 +145,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
150145
Evaluate the scorer on a test set and compute the specified metric.
151146
152147
:param context: Context containing test set and other data.
153-
:param split: Target split
148+
:param metrics: List of metric names to compute.
154149
:return: Computed metrics value for the test set or error code of metrics
155150
"""
156151
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from autointent import Context, Ranker, VectorIndex
1111
from autointent.custom_types import ListOfLabels
1212
from autointent.modules.abc import ScoringModule
13-
from autointent.schemas._schemas import CrossEncoderConfig, EmbedderConfig
13+
from autointent.schemas import CrossEncoderConfig, EmbedderConfig
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -87,17 +87,8 @@ def __init__(
8787
:param embedder_config: Config of the embedder model.
8888
:param k: Number of nearest neighbors to retrieve.
8989
"""
90-
if isinstance(cross_encoder_config, dict):
91-
cross_encoder_config = CrossEncoderConfig(**cross_encoder_config)
92-
if isinstance(cross_encoder_config, str):
93-
cross_encoder_config = CrossEncoderConfig(model_name=cross_encoder_config)
94-
self.cross_encoder_config = cross_encoder_config
95-
96-
if isinstance(embedder_config, dict):
97-
embedder_config = EmbedderConfig(**embedder_config)
98-
if isinstance(embedder_config, str):
99-
embedder_config = EmbedderConfig(model_name=embedder_config)
100-
self.embedder_config = embedder_config
90+
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
91+
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
10192
self.k = k
10293

10394
@classmethod

autointent/modules/scoring/_knn/knn.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from autointent import Context, VectorIndex
99
from autointent.custom_types import WEIGHT_TYPES, ListOfLabels
1010
from autointent.modules.abc import ScoringModule
11-
from autointent.schemas._schemas import EmbedderConfig
11+
from autointent.schemas import EmbedderConfig
1212

1313
from .weighting import apply_weights
1414

@@ -71,12 +71,7 @@ def __init__(
7171
- "distance": Weight inversely proportional to distance.
7272
- "closest": Only the closest neighbor of each class is weighted.
7373
"""
74-
if isinstance(embedder_config, dict):
75-
embedder_config = EmbedderConfig(**embedder_config)
76-
if isinstance(embedder_config, str):
77-
embedder_config = EmbedderConfig(model_name=embedder_config)
78-
79-
self.embedder_config = embedder_config
74+
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
8075
self.k = k
8176
self.weights = weights
8277

@@ -106,13 +101,13 @@ def from_context(
106101
weights=weights,
107102
)
108103

109-
def get_embedder_name(self) -> str:
104+
def get_embedder_config(self) -> EmbedderConfig:
110105
"""
111106
Get the name of the embedder.
112107
113108
:return: Embedder name.
114109
"""
115-
return self.embedder_config.model_name
110+
return self.embedder_config
116111

117112
def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
118113
"""

0 commit comments

Comments
 (0)