|
5 | 5 | import pytest |
6 | 6 | import torch |
7 | 7 | from sklearn.linear_model import LogisticRegression |
8 | | -from transformers import AutoModelForSequenceClassification, AutoTokenizer |
9 | 8 |
|
10 | 9 | from autointent import Embedder, Ranker, VectorIndex |
11 | 10 | from autointent._dump_tools import Dumper |
@@ -39,6 +38,8 @@ def check_attributes(self): |
39 | 38 |
|
40 | 39 | class TestTransformers: |
41 | 40 | def init_attributes(self): |
| 41 | + from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| 42 | + |
42 | 43 | self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") |
43 | 44 | self._tokenizer_predictions = np.array(self.tokenizer(["hello", "world"]).input_ids) |
44 | 45 | self.transformer = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased") |
@@ -143,16 +144,55 @@ def check_attributes(self): |
143 | 144 | assert not self.pydantic_model.tokenizer_config.truncation |
144 | 145 |
|
145 | 146 |
|
| 147 | +def _st_is_installed() -> bool: |
| 148 | + try: |
| 149 | + import sentence_transformers # noqa: F401 |
| 150 | + except ImportError: |
| 151 | + return False |
| 152 | + else: |
| 153 | + return True |
| 154 | + |
| 155 | + |
| 156 | +def _transformers_is_installed() -> bool: |
| 157 | + try: |
| 158 | + import transformers # noqa: F401 |
| 159 | + except ImportError: |
| 160 | + return False |
| 161 | + else: |
| 162 | + return True |
| 163 | + |
| 164 | + |
146 | 165 | @pytest.mark.parametrize( |
147 | 166 | "test_class", |
148 | 167 | [ |
149 | 168 | TestSimpleAttributes, |
150 | 169 | TestTags, |
151 | | - TestTransformers, |
| 170 | + pytest.param( |
| 171 | + TestTransformers, |
| 172 | + marks=pytest.mark.skipif( |
| 173 | + not _transformers_is_installed(), |
| 174 | + reason="need transformers dependency", |
| 175 | + ), |
| 176 | + id="transformer", |
| 177 | + ), |
152 | 178 | TestVectorIndex, |
153 | | - TestEmbedder, |
| 179 | + pytest.param( |
| 180 | + TestEmbedder, |
| 181 | + marks=pytest.mark.skipif( |
| 182 | + not _st_is_installed(), |
| 183 | + reason="need sentence-transformers dependency", |
| 184 | + ), |
| 185 | + id="embedder", |
| 186 | + ), |
154 | 187 | TestSklearnEstimator, |
155 | | - TestRanker, |
| 188 | + pytest.param( |
| 189 | + TestRanker, |
| 190 | + marks=pytest.mark.skipif( |
| 191 | + not _st_is_installed(), |
| 192 | + reason="need sentence-transformers dependency", |
| 193 | + ), |
| 194 | + id="ranker", |
| 195 | + ), |
156 | 196 | TestCrossEncoderConfig, |
157 | 197 | ], |
158 | 198 | ) |
|
0 commit comments