1717from autointent .modules .abc import EmbeddingModule
1818
1919
20- class RetrievalMetadata (BaseMetadataDict ):
21- """Metadata class for RetrievalEmbedding."""
22-
23- db_dir : str
24- batch_size : int
25- max_length : int | None
26-
27-
2820class LogRegMetadata (BaseMetadataDict ):
2921 """Metadata class for LogisticRegressionCV and LabelEncoder."""
3022
31- db_dir : str
32- batch_size : int
33- max_length : int | None
3423 classes : list [str ]
3524
3625
@@ -79,27 +68,26 @@ def __init__(
7968 k : int ,
8069 embedder_name : str ,
8170 cv : int = 3 ,
82- db_dir : str | None = None ,
8371 embedder_device : str = "cpu" ,
84- batch_size : int = 32 ,
85- max_length : int | None = None ,
86- embedder_use_cache : bool = False ,
72+ embedder_batch_size : int = 32 ,
73+ embedder_max_length : int | None = None ,
74+ embedder_use_cache : bool = True ,
8775 ) -> None :
8876 """
89- Initialize the RetrievalEmbedding .
77+ Initialize the LogRegEmbedding .
9078
79+ :param cv: Number of cross-validation folds used by LogisticRegressionCV.
80+ :param k: Number of nearest neighbors to retrieve.
9181 :param embedder_name: Name of the embedder used for creating embeddings.
92- :param db_dir: Path to the database directory. If None, defaults will be used.
9382 :param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
9483 :param embedder_batch_size: Batch size for embedding generation.
9584 :param embedder_max_length: Maximum sequence length for embeddings. None if not set.
9685 :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
9786 """
9887 self .embedder_name = embedder_name
9988 self .embedder_device = embedder_device
100- self ._db_dir = db_dir
101- self .batch_size = batch_size
102- self .max_length = max_length
89+ self .embedder_batch_size = embedder_batch_size
90+ self .embedder_max_length = embedder_max_length
10391 self .embedder_use_cache = embedder_use_cache
10492 self .cv = cv
10593
@@ -116,21 +104,25 @@ def from_context(
116104 """
117105 Create a LogRegEmbedding instance using a Context object.
118106
107+ :param cv: Number of cross-validation folds used by LogisticRegressionCV.
119108 :param context: The context containing configurations and utilities.
109+ :param k: Number of nearest neighbors to retrieve.
120110 :param embedder_name: Name of the embedder to use.
121111 :return: Initialized LogRegEmbedding instance.
122112 """
123113 return cls (
124114 k = k ,
125115 cv = cv ,
126116 embedder_name = embedder_name ,
127- db_dir = str (context .get_db_dir ()),
128117 embedder_device = context .get_device (),
129- batch_size = context .get_batch_size (),
130- max_length = context .get_max_length (),
118+ embedder_batch_size = context .get_batch_size (),
119+ embedder_max_length = context .get_max_length (),
131120 embedder_use_cache = context .get_use_cache (),
132121 )
133122
123+ def clear_cache (self ) -> None :
124+ """No-op kept for interface compatibility; the vector index this cleared was removed."""
125+
134126 def fit (self , utterances : list [str ], labels : list [LabelType ]) -> None :
135127 """
136128 Train the logistic regression model using the provided utterances and labels.
@@ -140,23 +132,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
140132 """
141133 self ._multilabel = isinstance (labels [0 ], list )
142134
143- self ._vector_index = VectorIndex (
144- self .embedder_name ,
145- self .embedder_device ,
146- self .embedder_batch_size ,
147- self .embedder_max_length ,
148- self .embedder_use_cache ,
149- )
150- self ._vector_index .add (utterances , labels )
151-
152135 self .embedder = Embedder (
153136 device = self .embedder_device ,
154- model_name = self .embedder_name ,
155- batch_size = self .batch_size ,
156- max_length = self .max_length ,
137+ model_name_or_path = self .embedder_name ,
138+ batch_size = self .embedder_batch_size ,
139+ max_length = self .embedder_max_length ,
157140 use_cache = self .embedder_use_cache ,
158141 )
159142 embeddings = self .embedder .embed (utterances )
143+
160144 if self ._multilabel :
161145 self .label_encoder = MultiLabelBinarizer ()
162146 encoded_labels = self .label_encoder .fit_transform (labels )
@@ -209,42 +193,33 @@ def get_assets(self) -> RetrieverArtifact:
209193 """
210194 return RetrieverArtifact (embedder_name = self .embedder_name )
211195
212- def clear_cache (self ) -> None :
213- """Clear cached data in memory used by the vector index."""
214- self .vector_index .clear_ram ()
215-
216- def dump (self , path : str ) -> None :
196+ def dump (self , path : Path ) -> None :
217197 """
218198 Save the module's metadata, classifier parameters, and label encoder to a specified directory.
219199
220200 :param path: Path to the directory where assets will be dumped.
221201 """
222- self .metadata = LogRegMetadata (
223- batch_size = self .batch_size ,
224- max_length = self .max_length ,
225- db_dir = str (self .db_dir ),
202+ metadata = LogRegMetadata (
226203 classes = self .label_encoder .classes_ .tolist (),
227204 )
228205
229- self ._vector_index .dump (Path (path ))
206+ path .mkdir (parents = True , exist_ok = True )
207+ with (path / self .metadata_dict_name ).open ("w" ) as file :
208+ json .dump (metadata , file , indent = 4 )
230209
231210 classifier_path = "classifier.joblib"
232- joblib .dump (self .classifier , classifier_path )
211+ joblib .dump (self .classifier , path / classifier_path )
233212
234- def load (self , path : str ) -> None :
213+ def load (self , path : Path ) -> None :
235214 """
236215 Load the module's metadata and model parameters from a specified directory.
237216
238217 :param path: Path to the directory containing the dumped assets.
239218 """
240- dump_dir = Path (path )
241-
242- with (dump_dir / self .metadata_dict_name ).open () as file :
219+ with (path / self .metadata_dict_name ).open () as file :
243220 self .metadata : LogRegMetadata = json .load (file )
244221
245- self ._vector_index = VectorIndex .load (Path (path ))
246-
247- classifier_path = dump_dir / "classifier.joblib"
222+ classifier_path = path / "classifier.joblib"
248223 self .classifier = joblib_load (classifier_path )
249224 self .label_encoder = LabelEncoder ()
250225 self .label_encoder .classes_ = self .metadata ["classes" ]
0 commit comments