44from pathlib import Path
55from typing import Literal
66
7+ import joblib
8+ from joblib import load as joblib_load
79from sklearn .linear_model import LogisticRegression , LogisticRegressionCV
810from sklearn .multioutput import MultiOutputClassifier
911from sklearn .preprocessing import LabelEncoder , MultiLabelBinarizer
1618from autointent .modules .abc import EmbeddingModule
1719
1820
class RetrievalMetadata(BaseMetadataDict):
    """Metadata class for RetrievalEmbedding."""

    # Written by ``dump`` as ``str(self.db_dir)`` so the value is JSON-serializable.
    db_dir: str
    # Written by ``dump`` from ``self.batch_size``.
    batch_size: int
    # Written by ``dump`` from ``self.max_length``; may be ``None``.
    max_length: int | None
2527
2628
class LogRegMetadata(BaseMetadataDict):
    """Metadata stored next to the dumped logistic-regression classifier.

    The fitted classifier itself is persisted separately (``classifier.joblib``);
    this dictionary only carries the JSON-serializable settings and the label
    encoder's classes.
    """

    # Written by ``dump`` as ``str(self.db_dir)`` so the value is JSON-serializable.
    db_dir: str
    # Written by ``dump`` from ``self.batch_size``.
    batch_size: int
    # Written by ``dump`` from ``self.max_length``; may be ``None``.
    max_length: int | None
    # ``label_encoder.classes_.tolist()`` at dump time; ``load`` assigns it back
    # to a fresh ``LabelEncoder``.
    classes: list[str]
3436
3537
@@ -154,6 +156,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
154156 """
155157 self ._multilabel = isinstance (labels [0 ], list )
156158
159+ vector_index_client = VectorIndexClient (
160+ self .embedder_device ,
161+ self .db_dir ,
162+ embedder_batch_size = self .batch_size ,
163+ embedder_max_length = self .max_length ,
164+ embedder_use_cache = self .embedder_use_cache ,
165+ )
166+ self .vector_index = vector_index_client .create_index (self .embedder_name , utterances , labels )
167+
157168 self .embedder = Embedder (
158169 device = self .embedder_device ,
159170 model_name = self .embedder_name ,
def dump(self, path: str) -> None:
    """
    Dump the module's assets into a directory.

    Three artifacts are written under ``path``:

    * the metadata dictionary, as JSON, in a file named ``self.metadata_dict_name``;
    * the vector index, via ``self.vector_index.dump``;
    * the fitted classifier, via ``joblib.dump``, as ``classifier.joblib``.

    :param path: Path to the directory where assets will be dumped.
    """
    self.metadata = LogRegMetadata(
        batch_size=self.batch_size,
        max_length=self.max_length,
        db_dir=str(self.db_dir),
        # ``tolist()`` turns the encoder's classes into plain Python values
        # so that ``json.dump`` can serialize them.
        classes=self.label_encoder.classes_.tolist(),
    )

    dump_dir = Path(path)
    with (dump_dir / self.metadata_dict_name).open("w") as file:
        json.dump(self.metadata, file, indent=4)
    self.vector_index.dump(dump_dir)

    # The whole estimator object is persisted; ``load`` reads it back from the
    # same file name.
    classifier_path = dump_dir / "classifier.joblib"
    joblib.dump(self.classifier, classifier_path)
246252
247253 def load (self , path : str ) -> None :
248254 """
@@ -251,8 +257,9 @@ def load(self, path: str) -> None:
251257 :param path: Path to the directory containing the dumped assets.
252258 """
253259 dump_dir = Path (path )
260+
254261 with (dump_dir / self .metadata_dict_name ).open () as file :
255- self .metadata : VectorDBMetadata = json .load (file )
262+ self .metadata : LogRegMetadata = json .load (file )
256263
257264 vector_index_client = VectorIndexClient (
258265 embedder_device = self .embedder_device ,
@@ -263,16 +270,10 @@ def load(self, path: str) -> None:
263270 )
264271 self .vector_index = vector_index_client .get_index (self .embedder_name )
265272
266- with (dump_dir / "classifier.json" ).open () as file :
267- self .classifier_metadata : ClassifierMetadata = json .load (file )
268-
269- self .classifier = LogisticRegressionCV ()
270- self .classifier .set_params (** self .classifier_metadata ["params" ])
271- self .classifier .coef_ = self .classifier_metadata ["coef_" ]
272- self .classifier .intercept_ = self .classifier_metadata ["intercept_" ]
273-
273+ classifier_path = dump_dir / "classifier.joblib"
274+ self .classifier = joblib_load (classifier_path )
274275 self .label_encoder = LabelEncoder ()
275- self .label_encoder .classes_ = self .classifier_metadata ["classes" ]
276+ self .label_encoder .classes_ = self .metadata ["classes" ]
276277
277278 def predict (self , utterances : list [str ]) -> tuple [list [list [int | list [int ]]], list [list [float ]], list [list [str ]]]:
278279 pass
@@ -448,7 +449,7 @@ def dump(self, path: str) -> None:
448449
449450 :param path: Path to the directory where assets will be dumped.
450451 """
451- self .metadata = VectorDBMetadata (
452+ self .metadata = RetrievalMetadata (
452453 batch_size = self .batch_size ,
453454 max_length = self .max_length ,
454455 db_dir = str (self .db_dir ),
@@ -467,7 +468,7 @@ def load(self, path: str) -> None:
467468 """
468469 dump_dir = Path (path )
469470 with (dump_dir / self .metadata_dict_name ).open () as file :
470- self .metadata : VectorDBMetadata = json .load (file )
471+ self .metadata : RetrievalMetadata = json .load (file )
471472
472473 vector_index_client = VectorIndexClient (
473474 embedder_device = self .embedder_device ,
0 commit comments