Skip to content

Commit 0b0c1fa

Browse files
committed
add comprehensive tests
1 parent fcf1f31 commit 0b0c1fa

File tree

9 files changed

+615
-8
lines changed

9 files changed

+615
-8
lines changed

tests/embedder/test_basic.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy as np
2+
import pytest
3+
4+
from autointent._wrappers.embedder import Embedder
5+
from autointent.configs._transformers import EmbedderConfig
6+
7+
8+
@pytest.fixture
def simple_embedder_config():
    """Provide a small, CPU-only embedder configuration for fast tests."""
    # Tiny model + small batch keeps fixture setup cheap; cache off so each
    # test computes embeddings from scratch.
    return EmbedderConfig(
        model_name="sergeyzh/rubert-tiny-turbo",
        batch_size=4,
        device="cpu",
        use_cache=False,
    )
17+
18+
19+
def test_embedding_calculation(simple_embedder_config):
    """Each utterance yields one embedding row, and rows are unit-normalized."""
    embedder = Embedder(simple_embedder_config)
    texts = ["Hello world", "Test sentence", "Another example"]

    vectors = embedder.embed(texts)

    assert vectors.shape[0] == len(texts)
    row_norms = np.linalg.norm(vectors, axis=1)
    assert np.allclose(row_norms, 1.0, atol=1e-5)  # embeddings come back L2-normalized
28+
29+
30+
def test_embedding_reproducibility(simple_embedder_config):
    """Embedding the same utterances twice yields identical vectors."""
    embedder = Embedder(simple_embedder_config)
    texts = ["Hello world", "Test sentence"]

    first_pass = embedder.embed(texts)
    second_pass = embedder.embed(texts)

    np.testing.assert_allclose(first_pass, second_pass, rtol=1e-5)
39+
40+
41+
def test_single_utterance(simple_embedder_config):
    """A one-element batch produces exactly one unit-norm embedding."""
    embedder = Embedder(simple_embedder_config)

    result = embedder.embed(["Single test sentence"])

    assert result.shape[0] == 1
    assert np.allclose(np.linalg.norm(result[0]), 1.0, atol=1e-5)
48+
49+
50+
def test_similarity_symmetry():
    """Cosine similarity must be symmetric: sim(a, b) equals sim(b, a) transposed."""
    config = EmbedderConfig(model_name="sergeyzh/rubert-tiny-turbo", similarity_fn_name="cosine", use_cache=False)
    embedder = Embedder(config)

    vectors = embedder.embed(["Hello world", "Test sentence"])

    forward = embedder.similarity(vectors[:1], vectors[1:])
    backward = embedder.similarity(vectors[1:], vectors[:1])

    np.testing.assert_allclose(forward, backward.T, rtol=1e-5)

tests/embedder/test_caching.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import time
2+
from unittest.mock import patch
3+
4+
import numpy as np
5+
6+
from autointent._wrappers.embedder import Embedder
7+
from autointent.configs._transformers import EmbedderConfig
8+
9+
10+
def test_caching_enabled():
    """Test that caching works when enabled.

    The second embed() of the same input must be served from the cache, i.e.
    the underlying model's encode() must not be invoked again. We assert on
    the mock's call count rather than wall-clock timing: timing thresholds
    (e.g. "5x faster") are flaky on loaded CI machines, while the call count
    is deterministic and checks exactly the property the test cares about.
    """
    config = EmbedderConfig(
        model_name="sergeyzh/rubert-tiny-turbo",
        use_cache=True,
        device="cpu",
    )
    embedder = Embedder(config)
    test_utterances = ["Cache test sentence"]

    # Mock the actual embedding calculation to verify caching
    with patch.object(embedder, "_load_model") as mock_load:
        mock_model = mock_load.return_value
        mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]])

        # First call computes embeddings via the (mocked) model.
        embeddings1 = embedder.embed(test_utterances)
        encode_calls_after_first = mock_model.encode.call_count

        # Second call should use cache (model.encode shouldn't be called again)
        embeddings2 = embedder.embed(test_utterances)
        encode_calls_after_second = mock_model.encode.call_count

        # Verify results are the same
        np.testing.assert_allclose(embeddings1, embeddings2, rtol=1e-5)

        assert encode_calls_after_second == encode_calls_after_first, (
            f"cache hit should not re-invoke the model: encode() called "
            f"{encode_calls_after_second - encode_calls_after_first} extra time(s)"
        )
41+
42+
43+
def test_caching_disabled():
    """With caching off, repeated embedding of the same input is still deterministic."""
    config = EmbedderConfig(
        model_name="sergeyzh/rubert-tiny-turbo",
        use_cache=False,
        device="cpu",
    )
    embedder = Embedder(config)
    texts = ["No cache test"]

    run_one = embedder.embed(texts)
    run_two = embedder.embed(texts)

    # Same model and same input, so the outputs must match even without a cache.
    np.testing.assert_allclose(run_one, run_two, rtol=1e-5)

tests/embedder/test_dump_load.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
import numpy as np
5+
from sentence_transformers import SentenceTransformer
6+
7+
from autointent._wrappers.embedder import Embedder
8+
from autointent.configs._transformers import EmbedderConfig
9+
10+
11+
def test_load_from_disk():
    """An Embedder built from locally saved weights survives a dump/load round trip."""
    model = SentenceTransformer("sergeyzh/rubert-tiny-turbo")

    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_dir = Path(tmp_dir) / "weights"
        model.save(str(weights_dir))

        embedder = Embedder(EmbedderConfig(model_name=str(weights_dir)))
        before = embedder.embed(["hi!"])

        dump_dir = Path(tmp_dir) / "embedder"
        embedder.dump(dump_dir)
        after = Embedder.load(dump_dir).embed(["hi!"])

        np.testing.assert_almost_equal(after, before, decimal=4)
23+
24+
25+
def test_dump_load_cycle():
    """Test complete dump/load cycle preserves functionality."""
    original_config = EmbedderConfig(
        model_name="sergeyzh/rubert-tiny-turbo",
        default_prompt="Test prompt:",
        similarity_fn_name="cosine",
        batch_size=4,
        use_cache=False,
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)
        texts = ["Test sentence for persistence"]

        # Embed with the original, persist it, then reload from disk.
        source = Embedder(original_config)
        expected = source.embed(texts)
        source.dump(workdir)
        restored = Embedder.load(workdir)

        # The reloaded embedder must reproduce the original vectors...
        np.testing.assert_allclose(expected, restored.embed(texts), rtol=1e-5)

        # ...and carry over the persisted configuration fields.
        assert restored.config.model_name == original_config.model_name
        assert restored.config.default_prompt == original_config.default_prompt
        assert restored.config.similarity_fn_name == original_config.similarity_fn_name
59+
def test_load_with_config_override():
    """Test loading with configuration override."""
    base_config = EmbedderConfig(
        model_name="sergeyzh/rubert-tiny-turbo",
        batch_size=8,
        use_cache=False,
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        dump_path = Path(temp_dir)

        # Persist an embedder built from the base configuration.
        Embedder(base_config).dump(dump_path)

        # Reload it while overriding only the batch size.
        loaded = Embedder.load(dump_path, EmbedderConfig(batch_size=16))

        # Overridden field takes effect; untouched fields come from disk.
        assert loaded.config.batch_size == 16
        assert loaded.config.model_name == base_config.model_name
Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,13 @@ def test_model_updates_after_training(dataset):
2222

2323
train_config = EmbedderFineTuningConfig(epoch_num=3, batch_size=8)
2424
embedder = Embedder(embedder_config)
25-
embedder._load_model()
25+
embedder._model = embedder._load_model()
2626

27-
for param in embedder.embedding_model.parameters():
27+
for param in embedder._model.parameters():
2828
assert param.requires_grad, "All trainable parameters should have requires_grad=True"
2929

3030
original_weights = [
31-
param.data.detach().cpu().numpy().copy()
32-
for param in embedder.embedding_model.parameters()
33-
if param.requires_grad
31+
param.data.detach().cpu().numpy().copy() for param in embedder._model.parameters() if param.requires_grad
3432
]
3533
embedder.train(
3634
utterances=data_handler.train_utterances(0)[:1000],
@@ -39,9 +37,7 @@ def test_model_updates_after_training(dataset):
3937
)
4038

4139
trained_weights = [
42-
param.data.detach().cpu().numpy().copy()
43-
for param in embedder.embedding_model.parameters()
44-
if param.requires_grad
40+
param.data.detach().cpu().numpy().copy() for param in embedder._model.parameters() if param.requires_grad
4541
]
4642

4743
weights_changed = any(

0 commit comments

Comments
 (0)