Skip to content

Commit 34fb1ca

Browse files
committed
fix: fixed retrieval and logreg test
1 parent a25afe2 commit 34fb1ca

File tree

2 files changed

+12
-21
lines changed

2 files changed

+12
-21
lines changed

tests/modules/retrieval/test_logreg.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import shutil
2-
from pathlib import Path
32
from unittest.mock import MagicMock
43

54
import numpy as np
@@ -9,15 +8,13 @@
98

109

1110
def test_get_assets_returns_correct_artifact_for_logreg():
12-
db_dir, dump_dir, logs_dir = setup_environment()
13-
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
11+
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
1412
artifact = module.get_assets()
1513
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
1614

1715

1816
def test_fit_trains_model():
19-
db_dir, dump_dir, logs_dir = setup_environment()
20-
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
17+
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
2118

2219
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
2320
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
@@ -29,8 +26,7 @@ def test_fit_trains_model():
2926

3027

3128
def test_score_evaluates_model():
32-
db_dir, dump_dir, logs_dir = setup_environment()
33-
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
29+
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
3430

3531
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
3632
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
@@ -50,24 +46,22 @@ def mock_metric_fn(true_labels, predicted_labels):
5046

5147

5248
def test_dump_and_load_preserves_model_state():
53-
db_dir, dump_dir, logs_dir = setup_environment()
54-
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
49+
dump_dir, _ = setup_environment()
50+
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
5551

5652
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
5753
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
5854
module.fit(utterances, labels)
5955

60-
dump_path = Path(dump_dir)
61-
dump_path.mkdir(parents=True, exist_ok=True)
62-
module.dump(str(dump_path))
56+
module.dump(dump_dir)
6357

64-
loaded_module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
65-
loaded_module.load(str(dump_path))
58+
loaded_module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
59+
loaded_module.load(dump_dir)
6660
epsilon = 1e-6
6761

6862
assert np.allclose(loaded_module.classifier.coef_, module.classifier.coef_, atol=epsilon)
6963
assert np.allclose(loaded_module.classifier.intercept_, module.classifier.intercept_, atol=epsilon)
7064
assert np.array_equal(np.array(loaded_module.label_encoder.classes_), np.array(module.label_encoder.classes_))
7165
assert loaded_module.embedder_name == module.embedder_name
7266

73-
shutil.rmtree(dump_path)
67+
shutil.rmtree(dump_dir)
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import shutil
2-
from pathlib import Path
32

43
from autointent.modules.embedding import RetrievalEmbedding
54
from tests.conftest import setup_environment
65

76

87
def test_get_assets_returns_correct_artifact():
9-
dump_dir, _ = setup_environment()
108
module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
119
artifact = module.get_assets()
1210
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
@@ -20,12 +18,11 @@ def test_dump_and_load_preserves_model_state():
2018
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
2119
module.fit(utterances, labels)
2220

23-
dump_path = Path(dump_dir)
24-
module.dump(dump_path)
21+
module.dump(dump_dir)
2522

2623
loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
27-
loaded_module.load(dump_path)
24+
loaded_module.load(dump_dir)
2825

2926
assert loaded_module.embedder_name == module.embedder_name
3027

31-
shutil.rmtree(dump_path)
28+
shutil.rmtree(dump_dir)

0 commit comments

Comments
 (0)