Skip to content

Commit 9ac001c

Browse files
authored
Reuse transformer code from Pystow (#485)
1 parent e06624d commit 9ac001c

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ dependencies = [
6868
"humanize",
6969
"tabulate",
7070
"cachier",
71-
"pystow<=0.7.23",
71+
"pystow>=0.7.28",
7272
"bioversions>=0.8.243",
7373
"bioregistry>=0.12.30",
7474
"ssslm>=0.0.13",

src/pyobo/api/embedding.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from __future__ import annotations
44

55
import tempfile
6+
import warnings
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Literal
89

910
import bioregistry
1011
import curies
1112
import numpy as np
1213
import pandas as pd
14+
from pystow import get_sentence_transformer
1315
from tqdm import tqdm
1416
from typing_extensions import Unpack
1517

@@ -34,10 +36,12 @@
3436

3537
def get_text_embedding_model() -> sentence_transformers.SentenceTransformer:
3638
"""Get the default text embedding model."""
37-
from sentence_transformers import SentenceTransformer
38-
39-
model = SentenceTransformer("all-MiniLM-L6-v2")
40-
return model
39+
warnings.warn(
40+
"get_text_embedding_model() is deprecated, use pystow.get_sentence_transformer() directly",
41+
DeprecationWarning,
42+
stacklevel=2,
43+
)
44+
return get_sentence_transformer()
4145

4246

4347
def _get_text(
@@ -157,7 +161,7 @@ def get_text_embeddings_df(
157161
luids.append(identifier)
158162
texts.append(text)
159163
if model is None:
160-
model = get_text_embedding_model()
164+
model = get_sentence_transformer()
161165
res = model.encode(texts, show_progress_bar=True)
162166
df = pd.DataFrame(res, index=luids)
163167
df.to_csv(path, sep="\t") # index is important here!
@@ -199,7 +203,7 @@ def get_text_embedding(
199203
if text is None:
200204
return None
201205
if model is None:
202-
model = get_text_embedding_model()
206+
model = get_sentence_transformer()
203207
res = model.encode([text])
204208
return res[0]
205209

@@ -239,7 +243,7 @@ def get_text_embedding_similarity(
239243
# 0.24702128767967224
240244
"""
241245
if model is None:
242-
model = get_text_embedding_model()
246+
model = get_sentence_transformer()
243247
e1 = get_text_embedding(reference_1, model=model)
244248
e2 = get_text_embedding(reference_2, model=model)
245249
if e1 is None or e2 is None:

0 commit comments

Comments
 (0)