 from autointent.custom_types import ListOfLabels
 from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
 from autointent.modules.abc import ScoringModule
-from autointent.schemas._schemas import EmbedderConfig
+from autointent.schemas import EmbedderConfig
 
 
 class DescriptionScorer(ScoringModule):
@@ -47,12 +47,7 @@ def __init__(
         :param temperature: Temperature parameter for scaling logits, defaults to 1.0.
         """
         self.temperature = temperature
-        if isinstance(embedder_config, dict):
-            embedder_config = EmbedderConfig(**embedder_config)
-        if isinstance(embedder_config, str):
-            embedder_config = EmbedderConfig(model_name=embedder_config)
-
-        self.embedder_config = embedder_config
+        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
 
     @classmethod
     def from_context(
@@ -77,7 +72,7 @@ def from_context(
             embedder_config=embedder_config,
         )
 
-    def get_embedder_name(self) -> EmbedderConfig:
+    def get_embedder_config(self) -> EmbedderConfig:
         """
         Get the name of the embedder.
 
@@ -129,10 +124,10 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
         similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors)
 
         if self._multilabel:
-            probabilites = scipy.special.expit(similarities / self.temperature)
+            probabilities = scipy.special.expit(similarities / self.temperature)
         else:
-            probabilites = scipy.special.softmax(similarities / self.temperature, axis=1)
-        return probabilites  # type: ignore[no-any-return]
+            probabilities = scipy.special.softmax(similarities / self.temperature, axis=1)
+        return probabilities  # type: ignore[no-any-return]
 
     def clear_cache(self) -> None:
         """Clear cached data in memory used by the embedder."""
@@ -150,7 +145,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
         Evaluate the scorer on a test set and compute the specified metric.
 
         :param context: Context containing test set and other data.
-        :param split: Target split
+        :param metrics: List of metric names to compute.
         :return: Computed metrics value for the test set or error code of metrics
         """
         metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
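
Note: judging only from the inline handling removed from __init__, EmbedderConfig.from_search_config presumably accepts a ready EmbedderConfig, a dict of keyword arguments, or a bare model-name string. A minimal sketch of that assumed behaviour follows (the class and method here are illustrative stand-ins; the real implementation lives in autointent.schemas and may differ):

    # Hypothetical sketch mirroring the normalization logic removed above.
    from dataclasses import dataclass

    @dataclass
    class EmbedderConfigSketch:
        model_name: str

        @classmethod
        def from_search_config(
            cls, value: "EmbedderConfigSketch | dict | str"
        ) -> "EmbedderConfigSketch":
            if isinstance(value, dict):
                return cls(**value)           # keyword arguments, e.g. {"model_name": "..."}
            if isinstance(value, str):
                return cls(model_name=value)  # bare model-name string
            return value                      # assumed to already be a config instance

    # Usage with a bare string (model name is a placeholder):
    config = EmbedderConfigSketch.from_search_config("some-embedding-model")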