Skip to content

Commit 18d035d

Browse files
committed
update tests
1 parent 0893800 commit 18d035d

File tree

11 files changed: +49 -39 lines changed

11 files changed: +49 -39 lines changed

tests/context/test_vector_index.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,9 @@
55
import pytest
66

77
from autointent import VectorIndex
8-
from autointent.configs import EmbedderConfig, FaissConfig, OpenSearchConfig, get_default_embedder_config
8+
from autointent.configs import EmbedderConfig, FaissConfig, OpenSearchConfig
99
from autointent.custom_types import Document
10+
from tests.conftest import get_test_embedder_config
1011

1112
# Check if opensearch-py is available
1213
opensearch_available = True
@@ -56,7 +57,7 @@ class TestVectorIndex:
5657
@pytest.fixture
5758
def embedder_config(self) -> EmbedderConfig:
5859
"""Create a lightweight embedder config for testing."""
59-
return get_default_embedder_config(model_name="sentence-transformers/all-MiniLM-L6-v2")
60+
return get_test_embedder_config()
6061

6162
@pytest.fixture
6263
def vector_index(self, embedder_config: EmbedderConfig, vector_config) -> VectorIndex:
@@ -242,16 +243,13 @@ def test_load_with_embedder_override(
242243
vector_index.dump(dump_path)
243244

244245
# Create override config
245-
override_config = get_default_embedder_config(model_name="sentence-transformers/all-MiniLM-L6-v2")
246-
override_config.device = "cpu"
247-
override_config.batch_size = 1
246+
override_config = get_test_embedder_config()
248247

249248
# Load with override
250249
loaded_index = VectorIndex.load(dump_path, embedder_override_config=override_config)
251250

252-
# Check that override was applied
253-
assert loaded_index.embedder.config.device == "cpu"
254-
assert loaded_index.embedder.config.batch_size == 1
251+
# Check that loaded index works with overridden config
252+
assert loaded_index.embedder.config.n_features == 512
255253

256254
def test_error_handling_mismatched_lengths(self, vector_index: VectorIndex):
257255
"""Test error handling when texts and labels have different lengths."""
@@ -287,7 +285,7 @@ def test_abstract_config_raises_error(self):
287285
"""Test that using abstract VectorIndexConfig raises an error."""
288286
from autointent.configs import VectorIndexConfig
289287

290-
embedder_config = get_default_embedder_config(model_name="sentence-transformers/all-MiniLM-L6-v2")
288+
embedder_config = get_test_embedder_config()
291289

292290
vector_index = VectorIndex(embedder_config=embedder_config, config=VectorIndexConfig())
293291
with pytest.raises(TypeError, match="Passed abstract vector index config"):

tests/embedder/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
import pytest
55

6-
from autointent.configs import OpenaiEmbeddingConfig, SentenceTransformerEmbeddingConfig
6+
from autointent.configs import (
7+
HashingVectorizerEmbeddingConfig,
8+
OpenaiEmbeddingConfig,
9+
SentenceTransformerEmbeddingConfig,
10+
)
711

812
# Check if OpenAI API key is available for testing
913
openai_available = os.getenv("OPENAI_API_KEY") is not None
@@ -18,6 +22,13 @@ def on_windows() -> bool:
1822

1923
# Backend configurations for parametrization
2024
backend_configs = [
25+
pytest.param(
26+
HashingVectorizerEmbeddingConfig(
27+
n_features=512,
28+
use_cache=False,
29+
),
30+
id="hashing_vectorizer",
31+
),
2132
pytest.param(
2233
SentenceTransformerEmbeddingConfig(
2334
model_name="sergeyzh/rubert-tiny-turbo",

tests/modules/embedding/test_logreg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import numpy as np
22

33
from autointent.modules.embedding import LogregAimedEmbedding
4-
from tests.conftest import setup_environment
4+
from tests.conftest import get_test_embedder_config, setup_environment
55

66

77
def test_get_assets_returns_correct_artifact_for_logreg():
8-
module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo")
8+
module = LogregAimedEmbedding(embedder_config=get_test_embedder_config())
99
artifact = module.get_assets()
10-
assert artifact.config.model_name == "sergeyzh/rubert-tiny-turbo"
10+
assert artifact.config.n_features == 512
1111

1212

1313
def test_fit_trains_model():
14-
module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo")
14+
module = LogregAimedEmbedding(embedder_config=get_test_embedder_config())
1515

1616
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
1717
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
@@ -23,7 +23,7 @@ def test_fit_trains_model():
2323

2424

2525
def test_predict_evaluates_model():
26-
module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo")
26+
module = LogregAimedEmbedding(embedder_config=get_test_embedder_config())
2727

2828
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
2929
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
@@ -37,7 +37,7 @@ def test_predict_evaluates_model():
3737

3838

3939
def test_dump_load():
40-
module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo")
40+
module = LogregAimedEmbedding(embedder_config=get_test_embedder_config())
4141
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
4242
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
4343
module.fit(utterances, labels)

tests/modules/embedding/test_retrieval.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from pathlib import Path
22

33
from autointent.modules.embedding import RetrievalAimedEmbedding
4+
from tests.conftest import get_test_embedder_config
45

56

67
def test_get_assets_returns_correct_artifact():
7-
module = RetrievalAimedEmbedding(k=5, embedder_config="sergeyzh/rubert-tiny-turbo")
8+
module = RetrievalAimedEmbedding(k=5, embedder_config=get_test_embedder_config())
89
artifact = module.get_assets()
9-
assert artifact.config.model_name == "sergeyzh/rubert-tiny-turbo"
10+
assert artifact.config.n_features == 512
1011

1112

1213
def test_dump_and_load_preserves_model_state(tmp_path: Path):
13-
module = RetrievalAimedEmbedding(k=5, embedder_config="sergeyzh/rubert-tiny-turbo")
14+
module = RetrievalAimedEmbedding(k=5, embedder_config=get_test_embedder_config())
1415

1516
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
1617
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]

tests/modules/scoring/test_catboost.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from autointent.context.data_handler import DataHandler
99
from autointent.modules import CatBoostScorer
10+
from tests.conftest import get_test_embedder_config
1011

1112
pytest.importorskip("catboost")
1213

@@ -55,7 +56,7 @@ def test_catboost_prediction_multilabel(dataset):
5556
data_handler = DataHandler(dataset.to_multilabel())
5657

5758
scorer = CatBoostScorer(
58-
embedder_config="prajjwal1/bert-tiny",
59+
embedder_config=get_test_embedder_config(),
5960
iterations=50,
6061
learning_rate=0.05,
6162
depth=6,
@@ -99,7 +100,7 @@ def test_catboost_features_types(dataset, features_type, use_embedding_features)
99100
data_handler = DataHandler(dataset)
100101

101102
scorer = CatBoostScorer(
102-
embedder_config="prajjwal1/bert-tiny",
103+
embedder_config=get_test_embedder_config(),
103104
iterations=50,
104105
learning_rate=0.05,
105106
depth=6,

tests/modules/scoring/test_gcn_scorer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from autointent import Dataset
66
from autointent.modules.scoring import GCNScorer
7+
from tests.conftest import get_test_embedder_config
78

89

910
@pytest.fixture
@@ -44,7 +45,7 @@ def multiclass_dataset():
4445

4546
def test_gcn_scorer_multilabel(multilabel_dataset):
4647
torch.manual_seed(42)
47-
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
48+
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
4849
train_utterances = multilabel_dataset["train"]["utterance"]
4950
train_labels = multilabel_dataset["train"]["label"]
5051
descriptions = [intent.name for intent in multilabel_dataset.intents]
@@ -59,7 +60,7 @@ def test_gcn_scorer_multilabel(multilabel_dataset):
5960

6061
def test_gcn_scorer_multiclass(multiclass_dataset):
6162
torch.manual_seed(42)
62-
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
63+
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
6364
train_utterances = multiclass_dataset["train"]["utterance"]
6465
train_labels = multiclass_dataset["train"]["label"]
6566
descriptions = [intent.name for intent in multiclass_dataset.intents]
@@ -75,7 +76,7 @@ def test_gcn_scorer_multiclass(multiclass_dataset):
7576

7677
def test_gcn_scorer_dump_load(tmp_path, multilabel_dataset):
7778
torch.manual_seed(42)
78-
scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
79+
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
7980
train_utterances = multilabel_dataset["train"]["utterance"]
8081
train_labels = multilabel_dataset["train"]["label"]
8182
descriptions = [intent.name for intent in multilabel_dataset.intents]

tests/modules/scoring/test_knn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from autointent.context.data_handler import DataHandler
66
from autointent.modules import KNNScorer
7+
from tests.conftest import get_test_embedder_config
78

89

910
def test_base_knn(dataset):
1011
data_handler = DataHandler(dataset)
1112

12-
scorer = KNNScorer(k=3, weights="distance", embedder_config="sergeyzh/rubert-tiny-turbo")
13+
scorer = KNNScorer(k=3, weights="distance", embedder_config=get_test_embedder_config())
1314

1415
test_data = [
1516
"why is there a hold on my american saving bank account",

tests/modules/scoring/test_linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from autointent.context.data_handler import DataHandler
66
from autointent.modules import LinearScorer
7+
from tests.conftest import get_test_embedder_config
78

89

910
def test_base_linear(dataset):
1011
data_handler = DataHandler(dataset)
1112

12-
scorer = LinearScorer(embedder_config="sergeyzh/rubert-tiny-turbo")
13+
scorer = LinearScorer(embedder_config=get_test_embedder_config())
1314

1415
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
1516
test_data = [

tests/modules/scoring/test_mlknn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from autointent.context.data_handler import DataHandler
66
from autointent.modules.scoring import MLKnnScorer
7+
from tests.conftest import get_test_embedder_config
78

89

910
def test_base_mlknn(dataset):
1011
data_handler = DataHandler(dataset.to_multilabel())
1112

12-
scorer = MLKnnScorer(embedder_config="sergeyzh/rubert-tiny-turbo", k=3)
13+
scorer = MLKnnScorer(embedder_config=get_test_embedder_config(), k=3)
1314
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
1415

1516
test_data = [

tests/modules/scoring/test_sklearn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
from autointent.context.data_handler import DataHandler
66
from autointent.modules import SklearnScorer
7+
from tests.conftest import get_test_embedder_config
78

89

910
def test_base_sklearn(dataset):
1011
data_handler = DataHandler(dataset)
1112

1213
scorer = SklearnScorer(
13-
embedder_config="sergeyzh/rubert-tiny-turbo",
14+
embedder_config=get_test_embedder_config(),
1415
clf_name="LogisticRegression",
1516
penalty="elasticnet",
1617
solver="saga",

0 commit comments

Comments
 (0)