Skip to content

Commit f3fe5cc

Browse files
committed
fix: load and dump
1 parent fe9a587 commit f3fe5cc

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
from typing import Literal
66

7+
import joblib
8+
from joblib import load as joblib_load
79
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
810
from sklearn.multioutput import MultiOutputClassifier
911
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
@@ -16,20 +18,20 @@
1618
from autointent.modules.abc import EmbeddingModule
1719

1820

19-
class VectorDBMetadata(BaseMetadataDict):
21+
class RetrievalMetadata(BaseMetadataDict):
2022
"""Metadata class for RetrievalEmbedding."""
2123

2224
db_dir: str
2325
batch_size: int
2426
max_length: int | None
2527

2628

27-
class ClassifierMetadata(BaseMetadataDict):
29+
class LogRegMetadata(BaseMetadataDict):
2830
"""Metadata class for LogisticRegressionCV and LabelEncoder."""
2931

30-
coef_: list[list[float]]
31-
intercept_: list[float]
32-
params: dict[str, any]
32+
db_dir: str
33+
batch_size: int
34+
max_length: int | None
3335
classes: list[str]
3436

3537

@@ -154,6 +156,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
154156
"""
155157
self._multilabel = isinstance(labels[0], list)
156158

159+
vector_index_client = VectorIndexClient(
160+
self.embedder_device,
161+
self.db_dir,
162+
embedder_batch_size=self.batch_size,
163+
embedder_max_length=self.max_length,
164+
embedder_use_cache=self.embedder_use_cache,
165+
)
166+
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
167+
157168
self.embedder = Embedder(
158169
device=self.embedder_device,
159170
model_name=self.embedder_name,
@@ -224,25 +235,20 @@ def dump(self, path: str) -> None:
224235
225236
:param path: Path to the directory where assets will be dumped.
226237
"""
227-
self.metadata = VectorDBMetadata(
238+
self.metadata = LogRegMetadata(
228239
batch_size=self.batch_size,
229240
max_length=self.max_length,
230241
db_dir=str(self.db_dir),
242+
classes=self.label_encoder.classes_.tolist(),
231243
)
232244

233245
dump_dir = Path(path)
234246
with (dump_dir / self.metadata_dict_name).open("w") as file:
235247
json.dump(self.metadata, file, indent=4)
236248
self.vector_index.dump(dump_dir)
237249

238-
self.classifier_metadata = ClassifierMetadata(
239-
coef_=self.classifier.coef_.tolist(),
240-
intercept_=self.classifier.intercept_.tolist(),
241-
classes=self.label_encoder.classes_.tolist(),
242-
params=self.classifier.get_params(),
243-
)
244-
with (dump_dir / "classifier.json").open("w") as file:
245-
json.dump(self.classifier_metadata, file, indent=4)
250+
classifier_path = dump_dir / "classifier.joblib"
251+
joblib.dump(self.classifier, classifier_path)
246252

247253
def load(self, path: str) -> None:
248254
"""
@@ -251,8 +257,9 @@ def load(self, path: str) -> None:
251257
:param path: Path to the directory containing the dumped assets.
252258
"""
253259
dump_dir = Path(path)
260+
254261
with (dump_dir / self.metadata_dict_name).open() as file:
255-
self.metadata: VectorDBMetadata = json.load(file)
262+
self.metadata: LogRegMetadata = json.load(file)
256263

257264
vector_index_client = VectorIndexClient(
258265
embedder_device=self.embedder_device,
@@ -263,16 +270,10 @@ def load(self, path: str) -> None:
263270
)
264271
self.vector_index = vector_index_client.get_index(self.embedder_name)
265272

266-
with (dump_dir / "classifier.json").open() as file:
267-
self.classifier_metadata: ClassifierMetadata = json.load(file)
268-
269-
self.classifier = LogisticRegressionCV()
270-
self.classifier.set_params(**self.classifier_metadata["params"])
271-
self.classifier.coef_ = self.classifier_metadata["coef_"]
272-
self.classifier.intercept_ = self.classifier_metadata["intercept_"]
273-
273+
classifier_path = dump_dir / "classifier.joblib"
274+
self.classifier = joblib_load(classifier_path)
274275
self.label_encoder = LabelEncoder()
275-
self.label_encoder.classes_ = self.classifier_metadata["classes"]
276+
self.label_encoder.classes_ = self.metadata["classes"]
276277

277278
def predict(self, utterances: list[str]) -> tuple[list[list[int | list[int]]], list[list[float]], list[list[str]]]:
278279
pass
@@ -448,7 +449,7 @@ def dump(self, path: str) -> None:
448449
449450
:param path: Path to the directory where assets will be dumped.
450451
"""
451-
self.metadata = VectorDBMetadata(
452+
self.metadata = RetrievalMetadata(
452453
batch_size=self.batch_size,
453454
max_length=self.max_length,
454455
db_dir=str(self.db_dir),
@@ -467,7 +468,7 @@ def load(self, path: str) -> None:
467468
"""
468469
dump_dir = Path(path)
469470
with (dump_dir / self.metadata_dict_name).open() as file:
470-
self.metadata: VectorDBMetadata = json.load(file)
471+
self.metadata: RetrievalMetadata = json.load(file)
471472

472473
vector_index_client = VectorIndexClient(
473474
embedder_device=self.embedder_device,

0 commit comments

Comments
 (0)