Skip to content

Commit 5d4a8c9

Browse files
committed
feat: updated test
1 parent cd75d72 commit 5d4a8c9

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

tests/modules/retrieval/test_logreg.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from unittest.mock import MagicMock
2-
31
from autointent.modules.embedding import LogRegEmbedding
42

53

@@ -21,17 +19,15 @@ def test_fit_trains_model():
2119
assert module._label_encoder.classes_.tolist() == [0, 1]
2220

2321

24-
def test_score_evaluates_model():
22+
def test_predict_evaluates_model():
2523
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
2624

2725
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
2826
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
2927
module.fit(utterances, labels)
3028

31-
mock_context = MagicMock()
32-
mock_context.data_handler.test_utterances.return_value = ["hello", "goodbye"]
33-
mock_context.data_handler.test_labels.return_value = [[1, 0], [0, 1]]
34-
35-
scores = module.score(mock_context, split="test")
29+
probas = module.predict(["hello", "bye"])
3630

37-
assert isinstance(scores, dict)
31+
assert len(probas) == 2
32+
assert probas[0][0] > probas[0][1]
33+
assert probas[1][1] > probas[1][0]

0 commit comments

Comments
 (0)