44from pathlib import Path
55from typing import Literal
66
7- from sklearn .linear_model import LogisticRegressionCV
8- from sklearn .preprocessing import LabelEncoder
7+ from sklearn .linear_model import LogisticRegression , LogisticRegressionCV
8+ from sklearn .multioutput import MultiOutputClassifier
9+ from sklearn .preprocessing import LabelEncoder , MultiLabelBinarizer
910
1011from autointent import Context , Embedder
1112from autointent .context .optimization_info import RetrieverArtifact
@@ -103,8 +104,7 @@ def __init__(
103104 self .batch_size = batch_size
104105 self .max_length = max_length
105106 self .embedder_use_cache = embedder_use_cache
106- self .classifier = LogisticRegressionCV (cv = cv )
107- self .label_encoder = LabelEncoder ()
107+ self .cv = cv
108108
109109 super ().__init__ (k = k )
110110
@@ -152,6 +152,8 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
152152 :param utterances: List of text data to index.
153153 :param labels: List of corresponding labels for the utterances.
154154 """
155+ self ._multilabel = isinstance (labels [0 ], list )
156+
155157 self .embedder = Embedder (
156158 device = self .embedder_device ,
157159 model_name = self .embedder_name ,
@@ -160,6 +162,16 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
160162 use_cache = self .embedder_use_cache ,
161163 )
162164 embeddings = self .embedder .embed (utterances )
165+ if self ._multilabel :
166+ self .label_encoder = MultiLabelBinarizer ()
167+ encoded_labels = self .label_encoder .fit_transform (labels )
168+ base_clf = LogisticRegression ()
169+ self .classifier = MultiOutputClassifier (base_clf )
170+ else :
171+ self .label_encoder = LabelEncoder ()
172+ encoded_labels = self .label_encoder .fit_transform (labels )
173+ self .classifier = LogisticRegressionCV (cv = self .cv )
174+
163175 self .label_encoder .fit (labels )
164176 encoded_labels = self .label_encoder .transform (labels )
165177 self .classifier .fit (embeddings , encoded_labels )
@@ -192,7 +204,7 @@ def score(
192204 predicted_encoded = self .classifier .predict (embeddings )
193205 predicted_labels = self .label_encoder .inverse_transform (predicted_encoded )
194206
195- return metric_fn (labels , [ predicted_labels ] )
207+ return metric_fn (labels , predicted_labels . reshape ( - 1 , 1 ) )
196208
197209 def get_assets (self ) -> RetrieverArtifact :
198210 """
@@ -262,6 +274,9 @@ def load(self, path: str) -> None:
262274 self .label_encoder = LabelEncoder ()
263275 self .label_encoder .classes_ = self .classifier_metadata ["classes" ]
264276
277+ def predict (self , utterances : list [str ]) -> tuple [list [list [int | list [int ]]], list [list [float ]], list [list [str ]]]:
278+ pass
279+
265280
266281class RetrievalEmbedding (EmbeddingModule ):
267282 r"""
0 commit comments