33from typing import Literal
44
55import numpy as np
6+ from numpy .typing import NDArray
67from sklearn .linear_model import LogisticRegression , LogisticRegressionCV
78from sklearn .multioutput import MultiOutputClassifier
89from sklearn .preprocessing import LabelEncoder
@@ -155,10 +156,7 @@ def score(
155156 message = f"Invalid split '{ split } ' provided. Expected one of 'validation', or 'test'."
156157 raise ValueError (message )
157158
158- embeddings = self ._embedder .embed (utterances )
159- probas = self ._classifier .predict_proba (embeddings )
160- if self ._multilabel :
161- probas = np .stack (probas , axis = 1 )[..., 1 ]
159+ probas = self .predict (utterances )
162160 metrics_dict = SCORING_METRICS_MULTILABEL if context .is_multilabel () else SCORING_METRICS_MULTICLASS
163161 return self .score_metrics ((labels , probas ), metrics_dict )
164162
@@ -170,5 +168,11 @@ def get_assets(self) -> RetrieverArtifact:
170168 """
171169 return RetrieverArtifact (embedder_name = self .embedder_name )
172170
173- def predict (self , utterances : list [str ]) -> None :
174- pass
171+ def predict (self , utterances : list [str ]) -> NDArray [np .float64 ] | list [NDArray [np .float64 ]]:
172+ embeddings = self ._embedder .embed (utterances )
173+ probas = self ._classifier .predict_proba (embeddings )
174+
175+ if self ._multilabel :
176+ probas = np .stack (probas , axis = 1 )[..., 1 ]
177+
178+ return probas
0 commit comments