@@ -23,6 +23,15 @@ class VectorDBMetadata(BaseMetadataDict):
2323 max_length : int | None
2424
2525
26+ class ClassifierMetadata (BaseMetadataDict ):
27+ """Metadata class for LogisticRegressionCV and LabelEncoder."""
28+
29+ coef_ : list [list [float ]]
30+ intercept_ : list [float ]
31+ params : dict [str , any ]
32+ classes : list [str ]
33+
34+
2635class LogRegEmbedding (EmbeddingModule ):
2736 r"""
2837 Module for managing classification operations using logistic regression.
@@ -63,6 +72,7 @@ class LogRegEmbedding(EmbeddingModule):
6372
6473 """
6574
75+ vector_index : VectorIndex
6676 classifier : LogisticRegressionCV
6777 label_encoder : LabelEncoder
6878 name = "logreg"
@@ -201,33 +211,29 @@ def clear_cache(self) -> None:
201211
202212 def dump (self , path : str ) -> None :
203213 """
204- Save the module's metadata and model parameters to a specified directory.
214+ Save the module's metadata, classifier parameters, and label encoder to a specified directory.
205215
206216 :param path: Path to the directory where assets will be dumped.
207217 """
208- metadata = VectorDBMetadata (
218+ self . metadata = VectorDBMetadata (
209219 batch_size = self .batch_size ,
210220 max_length = self .max_length ,
211- db_dir = self .db_dir ,
221+ db_dir = str ( self .db_dir ) ,
212222 )
213223
214224 dump_dir = Path (path )
215- with (dump_dir / "metadata.json" ).open ("w" ) as file :
216- json .dump (metadata .__dict__ , file , indent = 4 )
217-
218- model_path = dump_dir / "logreg_model.json"
219- with model_path .open ("w" ) as file :
220- json .dump (
221- {
222- "coef" : self .classifier .coef_ .tolist (),
223- "intercept" : self .classifier .intercept_ .tolist (),
224- "classes" : self .label_encoder .classes_ .tolist (),
225- },
226- file ,
227- indent = 4 ,
228- )
229-
230- super ().dump (path )
225+ with (dump_dir / self .metadata_dict_name ).open ("w" ) as file :
226+ json .dump (self .metadata , file , indent = 4 )
227+ self .vector_index .dump (dump_dir )
228+
229+ self .classifier_metadata = ClassifierMetadata (
230+ coef_ = self .classifier .coef_ .tolist (),
231+ intercept_ = self .classifier .intercept_ .tolist (),
232+ classes = self .label_encoder .classes_ .tolist (),
233+ params = self .classifier .get_params (),
234+ )
235+ with (dump_dir / "classifier.json" ).open ("w" ) as file :
236+ json .dump (self .classifier_metadata , file , indent = 4 )
231237
232238 def load (self , path : str ) -> None :
233239 """
@@ -236,24 +242,28 @@ def load(self, path: str) -> None:
236242 :param path: Path to the directory containing the dumped assets.
237243 """
238244 dump_dir = Path (path )
245+ with (dump_dir / self .metadata_dict_name ).open () as file :
246+ self .metadata : VectorDBMetadata = json .load (file )
239247
240- with (dump_dir / "metadata.json" ).open () as file :
241- metadata_dict = json .load (file )
242- self .batch_size = metadata_dict .get ("batch_size" , self .batch_size )
243- self .max_length = metadata_dict .get ("max_length" , self .max_length )
244- self ._db_dir = metadata_dict .get ("db_dir" , self ._db_dir )
245-
246- model_path = dump_dir / "logreg_model.json"
247- with model_path .open () as file :
248- model_data = json .load (file )
249- self .classifier = LogisticRegressionCV ()
250- self .k = model_data ["k" ]
251- self .classifier .coef_ = [model_data ["coef" ]]
252- self .classifier .intercept_ = model_data ["intercept" ]
253- self .label_encoder = LabelEncoder ()
254- self .label_encoder .classes_ = model_data ["classes" ]
255-
256- super ().load (path )
248+ vector_index_client = VectorIndexClient (
249+ embedder_device = self .embedder_device ,
250+ db_dir = self .metadata ["db_dir" ],
251+ embedder_batch_size = self .metadata ["batch_size" ],
252+ embedder_max_length = self .metadata ["max_length" ],
253+ embedder_use_cache = self .embedder_use_cache ,
254+ )
255+ self .vector_index = vector_index_client .get_index (self .embedder_name )
256+
257+ with (dump_dir / "classifier.json" ).open () as file :
258+ self .classifier_metadata : ClassifierMetadata = json .load (file )
259+
260+ self .classifier = LogisticRegressionCV ()
261+ self .classifier .set_params (** self .classifier_metadata ["params" ])
262+ self .classifier .coef_ = self .classifier_metadata ["coef_" ]
263+ self .classifier .intercept_ = self .classifier_metadata ["intercept_" ]
264+
265+ self .label_encoder = LabelEncoder ()
266+ self .label_encoder .classes_ = self .classifier_metadata ["classes" ]
257267
258268 def predict (self , utterances : list [str ]) -> list [int | list [int ]]:
259269 """
0 commit comments