Skip to content

Commit 5eeb377

Browse files
authored
Add standardized caching of sentence transformers (#137)
1 parent d39e5e6 commit 5eeb377

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/pystow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ensure_yaml,
3636
ensure_zip_df,
3737
ensure_zip_np,
38+
get_sentence_transformer,
3839
join,
3940
joinpath_sqlite,
4041
load_df,
@@ -91,6 +92,7 @@
9192
"ensure_zip_df",
9293
"ensure_zip_np",
9394
"get_config",
95+
"get_sentence_transformer",
9496
"join",
9597
"joinpath_sqlite",
9698
"load_df",

src/pystow/api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy.typing
2424
import pandas as pd
2525
import rdflib
26+
import sentence_transformers
2627

2728
__all__ = [
2829
"dump_df",
@@ -59,6 +60,7 @@
5960
"ensure_yaml",
6061
"ensure_zip_df",
6162
"ensure_zip_np",
63+
"get_sentence_transformer",
6264
"join",
6365
"joinpath_sqlite",
6466
"load_df",
@@ -2007,3 +2009,25 @@ def ensure_nltk(resource: str = "stopwords") -> tuple[Path, bool]:
20072009
# if the package was downloaded
20082010

20092011
return directory, result
2012+
2013+
2014+
def get_sentence_transformer(
2015+
name: str | None = None, **kwargs: Any
2016+
) -> sentence_transformers.SentenceTransformer:
2017+
"""Get a sentence transformer.
2018+
2019+
:param name: The name of the sentence transformer model on HuggingFace
2020+
:param kwargs: Keyword arguments to pass to
2021+
:class:`sentence_transformers.SentenceTransformer`.
2022+
2023+
:returns: An instantiated sentence transformer object, which has a
2024+
:meth:`sentence_transformers.SentenceTransformer.encode` function
2025+
"""
2026+
from sentence_transformers import SentenceTransformer
2027+
2028+
if name is None:
2029+
name = "all-MiniLM-L6-v2"
2030+
2031+
directory = join("sentence-transformers", name)
2032+
model = SentenceTransformer(name, cache_folder=directory.as_posix(), **kwargs)
2033+
return model

0 commit comments

Comments
 (0)