Skip to content

Commit d39f33b

Browse files
committed
Implement TF-IDF caching (allenai#549)
1 parent d837f26 commit d39f33b

File tree

1 file changed

+57
-17
lines changed

1 file changed

+57
-17
lines changed

scispacy/candidate_generation.py

Lines changed: 57 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,17 @@
1-
from typing import Optional, List, Dict, Tuple, NamedTuple, Type
1+
from typing import Optional, List, Dict, Tuple, NamedTuple, Type, Union
22
import json
33
import datetime
44
import os.path
55
from collections import defaultdict
6+
from pathlib import Path
67

78
import scipy
89
import numpy
910
import joblib
1011
from sklearn.feature_extraction.text import TfidfVectorizer
1112
import nmslib
1213
from nmslib.dist import FloatIndex
14+
from typing_extensions import Self
1315

1416
from scispacy.util import scipy_supports_sparse_float16
1517
from scispacy.file_cache import cached_path
@@ -41,6 +43,38 @@ class LinkerPaths(NamedTuple):
4143
tfidf_vectors: str
4244
concept_aliases_list: str
4345

46+
@classmethod
47+
def from_directory(cls, directory: Union[str, Path]) -> Self:
48+
if not os.path.isdir(directory):
49+
raise NotADirectoryError
50+
return cls(
51+
ann_index=os.path.join(directory, "nmslib_index.bin"),
52+
tfidf_vectorizer=os.path.join(directory, "tfidf_vectorizer.joblib"),
53+
tfidf_vectors=os.path.join(directory, "tfidf_vectors_sparse.npz"),
54+
concept_aliases_list=os.path.join(directory, "concept_aliases.json"),
55+
)
56+
57+
def is_locally_cached(self) -> bool:
    """
    Report whether every path held by this instance is an existing
    regular file on the local filesystem.

    Any entry for which `os.path.isfile` is false (for example a
    directory, a missing file, or a remote URL) makes the result False.
    """
    for location in self:
        if not os.path.isfile(location):
            return False
    return True
59+
60+
def get_concept_aliases(self) -> List[str]:
    """
    Load the concept alias list from the JSON file referenced by
    `concept_aliases_list`.

    The location is passed through `cached_path` before opening
    (presumably resolving a remote URL to a local copy; see
    `scispacy.file_cache`).

    Returns
    -------
    List[str]
        The decoded JSON content of the aliases file.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding, which differs between systems.
    with open(cached_path(self.concept_aliases_list), encoding="utf-8") as file:
        return json.load(file)
63+
64+
def get_tfidf_vectorizer(self) -> TfidfVectorizer:
    """
    Load the serialized tf-idf vectorizer referenced by `tfidf_vectorizer`.

    The location is first passed through `cached_path`, and the result is
    deserialized with `joblib.load`.
    """
    local_vectorizer_path = cached_path(self.tfidf_vectorizer)
    return joblib.load(local_vectorizer_path)
66+
67+
def get_ann_index(self, *, ef_search: int = 200) -> FloatIndex:
    """
    Load the approximate nearest neighbours index described by these paths.

    Delegates to `load_approximate_nearest_neighbours_index`, forwarding
    this instance and the keyword-only `ef_search` value unchanged.
    """
    index = load_approximate_nearest_neighbours_index(self, ef_search=ef_search)
    return index
69+
70+
def load(
    self, *, ef_search: int = 200
) -> Tuple[List[str], TfidfVectorizer, FloatIndex]:
    """
    Load all linker artifacts in one call.

    Returns
    -------
    A tuple of (concept aliases, tf-idf vectorizer, ann index), obtained
    in that order from `get_concept_aliases`, `get_tfidf_vectorizer` and
    `get_ann_index`; `ef_search` is forwarded to the index loader.
    """
    # Tuple elements are evaluated left to right, so the load order is
    # aliases, then vectorizer, then index.
    return (
        self.get_concept_aliases(),
        self.get_tfidf_vectorizer(),
        self.get_ann_index(ef_search=ef_search),
    )
77+
4478

4579
UmlsLinkerPaths = LinkerPaths(
4680
ann_index="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/nmslib_index.bin", # noqa
@@ -220,14 +254,10 @@ def __init__(
220254

221255
linker_paths = DEFAULT_PATHS.get(name, UmlsLinkerPaths)
222256

223-
self.ann_index = ann_index or load_approximate_nearest_neighbours_index(
224-
linker_paths=linker_paths, ef_search=ef_search
225-
)
226-
self.vectorizer = tfidf_vectorizer or joblib.load(
227-
cached_path(linker_paths.tfidf_vectorizer)
228-
)
229-
self.ann_concept_aliases_list = ann_concept_aliases_list or json.load(
230-
open(cached_path(linker_paths.concept_aliases_list))
257+
self.ann_index = ann_index or linker_paths.get_ann_index(ef_search=ef_search)
258+
self.vectorizer = tfidf_vectorizer or linker_paths.get_tfidf_vectorizer()
259+
self.ann_concept_aliases_list = (
260+
ann_concept_aliases_list or linker_paths.get_concept_aliases()
231261
)
232262

233263
self.kb = kb or DEFAULT_KNOWLEDGE_BASES[name]()
@@ -364,7 +394,10 @@ def __call__(
364394

365395

366396
def create_tfidf_ann_index(
367-
out_path: Optional[str], kb: Optional[KnowledgeBase] = None
397+
out_path: Optional[str],
398+
kb: Optional[KnowledgeBase] = None,
399+
*,
400+
ef_search: int = 200,
368401
) -> Tuple[List[str], TfidfVectorizer, FloatIndex]:
369402
"""
370403
Build tfidf vectorizer and ann index.
@@ -378,6 +411,13 @@ def create_tfidf_ann_index(
378411
The kb items to generate the index and vectors for.
379412
380413
"""
414+
if out_path is None:
415+
linker_paths = None
416+
else:
417+
linker_paths = LinkerPaths.from_directory(out_path)
418+
if linker_paths.is_locally_cached():
419+
return linker_paths.load(ef_search=ef_search)
420+
381421
if not scipy_supports_sparse_float16():
382422
raise RuntimeError(
383423
"This function requires scipy<1.11, which only runs on Python<3.11."
@@ -419,8 +459,8 @@ def create_tfidf_ann_index(
419459
)
420460
start_time = datetime.datetime.now()
421461
concept_alias_tfidfs = tfidf_vectorizer.fit_transform(concept_aliases)
422-
if out_path is not None:
423-
tfidf_vectorizer_path = os.path.join(out_path, "tfidf_vectorizer.joblib")
462+
if linker_paths is not None:
463+
tfidf_vectorizer_path = linker_paths.tfidf_vectorizer
424464
print(f"Saving tfidf vectorizer to {tfidf_vectorizer_path}")
425465
joblib.dump(tfidf_vectorizer, tfidf_vectorizer_path)
426466
end_time = datetime.datetime.now()
@@ -446,9 +486,9 @@ def create_tfidf_ann_index(
446486
concept_alias_tfidfs = concept_alias_tfidfs[empty_tfidfs_boolean_flags]
447487
assert len(concept_aliases) == numpy.size(concept_alias_tfidfs, 0)
448488

449-
if out_path is not None:
450-
tfidf_vectors_path = os.path.join(out_path, "tfidf_vectors_sparse.npz")
451-
concept_aliases_path = os.path.join(out_path, "concept_aliases.json")
489+
if linker_paths is not None:
490+
tfidf_vectors_path = linker_paths.tfidf_vectors
491+
concept_aliases_path = linker_paths.concept_aliases_list
452492
print(
453493
f"Saving list of concept ids and tfidfs vectors to {concept_aliases_path} and {tfidf_vectors_path}"
454494
)
@@ -467,8 +507,8 @@ def create_tfidf_ann_index(
467507
)
468508
ann_index.addDataPointBatch(concept_alias_tfidfs)
469509
ann_index.createIndex(index_params, print_progress=True)
470-
if out_path is not None:
471-
ann_index_path = os.path.join(out_path, "nmslib_index.bin")
510+
if linker_paths is not None:
511+
ann_index_path = linker_paths.ann_index
472512
ann_index.saveIndex(ann_index_path)
473513
end_time = datetime.datetime.now()
474514
elapsed_time = end_time - start_time

0 commit comments

Comments
 (0)