Skip to content

Commit 9704c29

Browse files
committed
make tests runnable even without sentence-transformers dependency
1 parent 5e2ad52 commit 9704c29

File tree

8 files changed

+59
-4
lines changed

8 files changed

+59
-4
lines changed

tests/_transformers/test_nli_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from autointent import Dataset, Ranker
66
from autointent.context.data_handler import DataHandler
77

8+
pytest.importorskip("sentence-transformers")
9+
810

911
@pytest.fixture
1012
def data_handler():

tests/embedder/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# Check if OpenAI API key is available for testing
99
openai_available = os.getenv("OPENAI_API_KEY") is not None
1010

11+
pytest.importorskip("sentence-transformers")
12+
1113

1214
@pytest.fixture
1315
def on_windows() -> bool:

tests/modules/scoring/test_bert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from autointent.context.data_handler import DataHandler
99
from autointent.modules import BertScorer
1010

11+
pytest.importorskip("transformers")
12+
1113

1214
def test_bert_scorer_dump_load(dataset):
1315
"""Test that BertScorer can be saved and loaded while preserving predictions."""

tests/modules/scoring/test_description_bi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from autointent.context.data_handler import DataHandler
77
from autointent.modules import BiEncoderDescriptionScorer
88

9+
pytest.importorskip("sentence-transformers")
10+
911

1012
@pytest.mark.parametrize(
1113
("expected_prediction", "multilabel"),

tests/modules/scoring/test_description_cross.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from autointent.context.data_handler import DataHandler
77
from autointent.modules import CrossEncoderDescriptionScorer
88

9+
pytest.importorskip("sentence-transformers")
10+
911

1012
@pytest.mark.parametrize(
1113
("expected_prediction", "multilabel"),

tests/modules/scoring/test_dnnc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from autointent.context.data_handler import DataHandler
77
from autointent.modules import DNNCScorer
88

9+
pytest.importorskip("sentence-transformers")
10+
911

1012
@pytest.mark.parametrize(("train_head", "pred_score"), [(True, 1)])
1113
def test_base_dnnc(dataset, train_head, pred_score):

tests/modules/scoring/test_rerank_scorer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import tempfile
22

33
import numpy as np
4+
import pytest
45

56
from autointent.context.data_handler import DataHandler
67
from autointent.modules import RerankScorer
78

9+
pytest.importorskip("sentence-transformers")
10+
811

912
def test_base_rerank_scorer(dataset):
1013
data_handler = DataHandler(dataset)

tests/modules/test_dumper.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66
import torch
77
from sklearn.linear_model import LogisticRegression
8-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
98

109
from autointent import Embedder, Ranker, VectorIndex
1110
from autointent._dump_tools import Dumper
@@ -39,6 +38,8 @@ def check_attributes(self):
3938

4039
class TestTransformers:
4140
def init_attributes(self):
41+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
42+
4243
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
4344
self._tokenizer_predictions = np.array(self.tokenizer(["hello", "world"]).input_ids)
4445
self.transformer = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
@@ -143,16 +144,55 @@ def check_attributes(self):
143144
assert not self.pydantic_model.tokenizer_config.truncation
144145

145146

147+
def _st_is_installed() -> bool:
148+
try:
149+
import sentence_transformers # noqa: F401
150+
except ImportError:
151+
return False
152+
else:
153+
return True
154+
155+
156+
def _transformers_is_installed() -> bool:
157+
try:
158+
import transformers # noqa: F401
159+
except ImportError:
160+
return False
161+
else:
162+
return True
163+
164+
146165
@pytest.mark.parametrize(
147166
"test_class",
148167
[
149168
TestSimpleAttributes,
150169
TestTags,
151-
TestTransformers,
170+
pytest.param(
171+
TestTransformers,
172+
marks=pytest.mark.skipif(
173+
not _transformers_is_installed(),
174+
reason="need transformers dependency",
175+
),
176+
id="transformer",
177+
),
152178
TestVectorIndex,
153-
TestEmbedder,
179+
pytest.param(
180+
TestEmbedder,
181+
marks=pytest.mark.skipif(
182+
not _st_is_installed(),
183+
reason="need sentence-transformers dependency",
184+
),
185+
id="embedder",
186+
),
154187
TestSklearnEstimator,
155-
TestRanker,
188+
pytest.param(
189+
TestRanker,
190+
marks=pytest.mark.skipif(
191+
not _st_is_installed(),
192+
reason="need sentence-transformers dependency",
193+
),
194+
id="ranker",
195+
),
156196
TestCrossEncoderConfig,
157197
],
158198
)

0 commit comments

Comments
 (0)