
Commit d402dba

fixed: dump & load modules and added tests
1 parent f8e7de7 commit d402dba

File tree

4 files changed: +151 -45 lines changed


autointent/modules/embedding/_retrieval.py

Lines changed: 46 additions & 36 deletions
@@ -23,6 +23,15 @@ class VectorDBMetadata(BaseMetadataDict):
     max_length: int | None
 
 
+class ClassifierMetadata(BaseMetadataDict):
+    """Metadata class for LogisticRegressionCV and LabelEncoder."""
+
+    coef_: list[list[float]]
+    intercept_: list[float]
+    params: dict[str, any]
+    classes: list[str]
+
+
 class LogRegEmbedding(EmbeddingModule):
     r"""
     Module for managing classification operations using logistic regression.
@@ -63,6 +72,7 @@ class LogRegEmbedding(EmbeddingModule):
 
     """
 
+    vector_index: VectorIndex
     classifier: LogisticRegressionCV
     label_encoder: LabelEncoder
     name = "logreg"
@@ -201,33 +211,29 @@ def clear_cache(self) -> None:
 
     def dump(self, path: str) -> None:
         """
-        Save the module's metadata and model parameters to a specified directory.
+        Save the module's metadata, classifier parameters, and label encoder to a specified directory.
 
         :param path: Path to the directory where assets will be dumped.
         """
-        metadata = VectorDBMetadata(
+        self.metadata = VectorDBMetadata(
             batch_size=self.batch_size,
             max_length=self.max_length,
-            db_dir=self.db_dir,
+            db_dir=str(self.db_dir),
         )
 
         dump_dir = Path(path)
-        with (dump_dir / "metadata.json").open("w") as file:
-            json.dump(metadata.__dict__, file, indent=4)
-
-        model_path = dump_dir / "logreg_model.json"
-        with model_path.open("w") as file:
-            json.dump(
-                {
-                    "coef": self.classifier.coef_.tolist(),
-                    "intercept": self.classifier.intercept_.tolist(),
-                    "classes": self.label_encoder.classes_.tolist(),
-                },
-                file,
-                indent=4,
-            )
-
-        super().dump(path)
+        with (dump_dir / self.metadata_dict_name).open("w") as file:
+            json.dump(self.metadata, file, indent=4)
+        self.vector_index.dump(dump_dir)
+
+        self.classifier_metadata = ClassifierMetadata(
+            coef_=self.classifier.coef_.tolist(),
+            intercept_=self.classifier.intercept_.tolist(),
+            classes=self.label_encoder.classes_.tolist(),
+            params=self.classifier.get_params(),
+        )
+        with (dump_dir / "classifier.json").open("w") as file:
+            json.dump(self.classifier_metadata, file, indent=4)
 
     def load(self, path: str) -> None:
         """
@@ -236,24 +242,28 @@ def load(self, path: str) -> None:
         :param path: Path to the directory containing the dumped assets.
         """
         dump_dir = Path(path)
+        with (dump_dir / self.metadata_dict_name).open() as file:
+            self.metadata: VectorDBMetadata = json.load(file)
 
-        with (dump_dir / "metadata.json").open() as file:
-            metadata_dict = json.load(file)
-        self.batch_size = metadata_dict.get("batch_size", self.batch_size)
-        self.max_length = metadata_dict.get("max_length", self.max_length)
-        self._db_dir = metadata_dict.get("db_dir", self._db_dir)
-
-        model_path = dump_dir / "logreg_model.json"
-        with model_path.open() as file:
-            model_data = json.load(file)
-        self.classifier = LogisticRegressionCV()
-        self.k = model_data["k"]
-        self.classifier.coef_ = [model_data["coef"]]
-        self.classifier.intercept_ = model_data["intercept"]
-        self.label_encoder = LabelEncoder()
-        self.label_encoder.classes_ = model_data["classes"]
-
-        super().load(path)
+        vector_index_client = VectorIndexClient(
+            embedder_device=self.embedder_device,
+            db_dir=self.metadata["db_dir"],
+            embedder_batch_size=self.metadata["batch_size"],
+            embedder_max_length=self.metadata["max_length"],
+            embedder_use_cache=self.embedder_use_cache,
+        )
+        self.vector_index = vector_index_client.get_index(self.embedder_name)
+
+        with (dump_dir / "classifier.json").open() as file:
+            self.classifier_metadata: ClassifierMetadata = json.load(file)
+
+        self.classifier = LogisticRegressionCV()
+        self.classifier.set_params(**self.classifier_metadata["params"])
+        self.classifier.coef_ = self.classifier_metadata["coef_"]
+        self.classifier.intercept_ = self.classifier_metadata["intercept_"]
+
+        self.label_encoder = LabelEncoder()
+        self.label_encoder.classes_ = self.classifier_metadata["classes"]
 
     def predict(self, utterances: list[str]) -> list[int | list[int]]:
         """
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import shutil
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import numpy as np
+
+from autointent.modules.embedding import LogRegEmbedding
+from tests.conftest import setup_environment
+
+
+def test_get_assets_returns_correct_artifact_for_logreg():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+    artifact = module.get_assets()
+    assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
+
+
+def test_fit_trains_model():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+
+    utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
+    labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
+    module.fit(utterances, labels)
+
+    assert module.classifier.coef_ is not None
+    assert len(module.classifier.coef_) > 0
+    assert module.label_encoder.classes_.tolist() == [0, 1]
+
+
+def test_score_evaluates_model():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+
+    utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
+    labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
+    module.fit(utterances, labels)
+
+    mock_context = MagicMock()
+    mock_context.data_handler.test_utterances.return_value = ["hello", "goodbye"]
+    mock_context.data_handler.test_labels.return_value = [0, 1]
+
+    def mock_metric_fn(true_labels, predicted_labels):
+        return sum(1 for t, p in zip(true_labels, predicted_labels[0], strict=False) if t == p) / len(true_labels)
+
+    score = module.score(mock_context, split="test", metric_fn=mock_metric_fn)
+
+    assert 0 <= score <= 1
+    assert score > 0
+
+
+def test_dump_and_load_preserves_model_state():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+
+    utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
+    labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
+    module.fit(utterances, labels)
+
+    dump_path = Path(dump_dir)
+    dump_path.mkdir(parents=True, exist_ok=True)
+    module.dump(str(dump_path))
+
+    loaded_module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+    loaded_module.load(str(dump_path))
+    epsilon = 1e-6
+
+    assert np.allclose(loaded_module.classifier.coef_, module.classifier.coef_, atol=epsilon)
+    assert np.allclose(loaded_module.classifier.intercept_, module.classifier.intercept_, atol=epsilon)
+    assert np.array_equal(np.array(loaded_module.label_encoder.classes_), np.array(module.label_encoder.classes_))
+    assert loaded_module.embedder_name == module.embedder_name
+
+    shutil.rmtree(dump_path)
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import shutil
+from pathlib import Path
+
+from autointent.modules.embedding import RetrievalEmbedding
+from tests.conftest import setup_environment
+
+
+def test_get_assets_returns_correct_artifact():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+    artifact = module.get_assets()
+    assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
+
+
+def test_dump_and_load_preserves_model_state():
+    db_dir, dump_dir, logs_dir = setup_environment()
+    module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+
+    utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
+    labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
+    module.fit(utterances, labels)
+
+    dump_path = Path(dump_dir)
+    dump_path.mkdir(parents=True, exist_ok=True)
+    module.dump(str(dump_path))
+
+    loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
+    loaded_module.load(str(dump_path))
+
+    assert loaded_module.embedder_name == module.embedder_name
+
+    shutil.rmtree(dump_path)

tests/modules/retrieval/test_vectordb.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
