Skip to content

Commit e95e11c

Browse files
committed
fix embedder tests
1 parent 6507f5e commit e95e11c

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

tests/embedder/test_dump_load.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def embedder(self, embedder_config: EmbedderConfig) -> Embedder:
3535
"""Create an Embedder instance for testing."""
3636
return Embedder(embedder_config)
3737

38-
def test_dump_load_cycle(self, embedder: Embedder, on_windows):
38+
def test_dump_load_cycle(self, embedder: Embedder, on_windows, embedder_config: EmbedderConfig): # noqa: ARG002
3939
"""Test complete dump/load cycle preserves functionality."""
4040
with tempfile.TemporaryDirectory(ignore_cleanup_errors=on_windows) as temp_dir:
4141
temp_path = Path(temp_dir)
@@ -54,13 +54,22 @@ def test_dump_load_cycle(self, embedder: Embedder, on_windows):
5454
loaded_embeddings = embedder_loaded.embed(test_utterances)
5555
np.testing.assert_allclose(original_embeddings, loaded_embeddings, rtol=1e-3)
5656

57-
# Test configuration preservation
58-
assert embedder_loaded.config.model_name == embedder.config.model_name
59-
assert embedder_loaded.config.default_prompt == embedder.config.default_prompt
60-
assert embedder_loaded.config.batch_size == embedder.config.batch_size
57+
# Test configuration preservation (only for configs that have these attributes)
58+
if hasattr(embedder.config, "model_name"):
59+
assert embedder_loaded.config.model_name == embedder.config.model_name
60+
if hasattr(embedder.config, "default_prompt"):
61+
assert embedder_loaded.config.default_prompt == embedder.config.default_prompt
62+
if hasattr(embedder.config, "batch_size"):
63+
assert embedder_loaded.config.batch_size == embedder.config.batch_size
6164

62-
def test_load_with_config_override(self, embedder: Embedder, on_windows):
65+
def test_load_with_config_override(self, embedder: Embedder, on_windows, embedder_config: EmbedderConfig): # noqa: ARG002
6366
"""Test loading with configuration override."""
67+
from autointent.configs import HashingVectorizerEmbeddingConfig, OpenaiEmbeddingConfig
68+
69+
# Skip for HashingVectorizer as it doesn't support batch_size override
70+
if isinstance(embedder.config, HashingVectorizerEmbeddingConfig):
71+
pytest.skip("HashingVectorizer doesn't support batch_size configuration")
72+
6473
with tempfile.TemporaryDirectory(ignore_cleanup_errors=on_windows) as temp_dir:
6574
temp_path = Path(temp_dir)
6675

@@ -72,8 +81,6 @@ def test_load_with_config_override(self, embedder: Embedder, on_windows):
7281
override_config = SentenceTransformerEmbeddingConfig(batch_size=16)
7382
else:
7483
# For OpenAI, we can override batch_size too
75-
from autointent.configs import OpenaiEmbeddingConfig
76-
7784
override_config = OpenaiEmbeddingConfig(batch_size=16)
7885

7986
# Load with override

tests/embedder/test_prompts.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ class TestEmbedderPrompts:
1414
@pytest.fixture
1515
def prompt_embedder_config(self, embedder_config: EmbedderConfig) -> EmbedderConfig:
1616
"""Create embedder config with different prompts based on backend type."""
17+
from autointent.configs import HashingVectorizerEmbeddingConfig
18+
19+
# Skip for HashingVectorizer as it doesn't support prompts
20+
if isinstance(embedder_config, HashingVectorizerEmbeddingConfig):
21+
pytest.skip("HashingVectorizer doesn't support prompts")
22+
1723
if hasattr(embedder_config, "similarity_fn_name"):
1824
# SentenceTransformers config
1925
return create_sentence_transformer_config(
@@ -49,6 +55,12 @@ def test_different_task_prompts(self, prompt_embedder_config: EmbedderConfig):
4955

5056
def test_fallback_to_default_prompt(self, embedder_config: EmbedderConfig):
5157
"""Test fallback to default prompt when specific prompt not set."""
58+
from autointent.configs import HashingVectorizerEmbeddingConfig
59+
60+
# Skip for HashingVectorizer as it doesn't support prompts
61+
if isinstance(embedder_config, HashingVectorizerEmbeddingConfig):
62+
pytest.skip("HashingVectorizer doesn't support prompts")
63+
5264
if hasattr(embedder_config, "similarity_fn_name"):
5365
# SentenceTransformers config
5466
config = create_sentence_transformer_config(

0 commit comments

Comments
 (0)