|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import tempfile |
| 6 | +import warnings |
6 | 7 | from pathlib import Path |
7 | 8 | from typing import TYPE_CHECKING, Literal |
8 | 9 |
|
9 | 10 | import bioregistry |
10 | 11 | import curies |
11 | 12 | import numpy as np |
12 | 13 | import pandas as pd |
| 14 | +from pystow import get_sentence_transformer |
13 | 15 | from tqdm import tqdm |
14 | 16 | from typing_extensions import Unpack |
15 | 17 |
|
|
34 | 36 |
|
35 | 37 | def get_text_embedding_model() -> sentence_transformers.SentenceTransformer: |
36 | 38 | """Get the default text embedding model.""" |
37 | | - from sentence_transformers import SentenceTransformer |
38 | | - |
39 | | - model = SentenceTransformer("all-MiniLM-L6-v2") |
40 | | - return model |
| 39 | + warnings.warn( |
| 40 | + "get_text_embedding_model() is deprecated, use pystow.get_sentence_transfomer() directly", |
| 41 | + DeprecationWarning, |
| 42 | + stacklevel=2, |
| 43 | + ) |
| 44 | + return get_sentence_transformer() |
41 | 45 |
|
42 | 46 |
|
43 | 47 | def _get_text( |
@@ -157,7 +161,7 @@ def get_text_embeddings_df( |
157 | 161 | luids.append(identifier) |
158 | 162 | texts.append(text) |
159 | 163 | if model is None: |
160 | | - model = get_text_embedding_model() |
| 164 | + model = get_sentence_transformer() |
161 | 165 | res = model.encode(texts, show_progress_bar=True) |
162 | 166 | df = pd.DataFrame(res, index=luids) |
163 | 167 | df.to_csv(path, sep="\t") # index is important here! |
@@ -199,7 +203,7 @@ def get_text_embedding( |
199 | 203 | if text is None: |
200 | 204 | return None |
201 | 205 | if model is None: |
202 | | - model = get_text_embedding_model() |
| 206 | + model = get_sentence_transformer() |
203 | 207 | res = model.encode([text]) |
204 | 208 | return res[0] |
205 | 209 |
|
@@ -239,7 +243,7 @@ def get_text_embedding_similarity( |
239 | 243 | # 0.24702128767967224 |
240 | 244 | """ |
241 | 245 | if model is None: |
242 | | - model = get_text_embedding_model() |
| 246 | + model = get_sentence_transformer() |
243 | 247 | e1 = get_text_embedding(reference_1, model=model) |
244 | 248 | e2 = get_text_embedding(reference_2, model=model) |
245 | 249 | if e1 is None or e2 is None: |
|
0 commit comments