1- from typing import Optional , List , Dict , Tuple , NamedTuple , Type
1+ from typing import Optional , List , Dict , Tuple , NamedTuple , Type , Union
22import json
33import datetime
44import os .path
55from collections import defaultdict
6+ from pathlib import Path
67
78import scipy
89import numpy
910import joblib
1011from sklearn .feature_extraction .text import TfidfVectorizer
1112import nmslib
1213from nmslib .dist import FloatIndex
14+ from typing_extensions import Self
1315
1416from scispacy .util import scipy_supports_sparse_float16
1517from scispacy .file_cache import cached_path
@@ -41,6 +43,38 @@ class LinkerPaths(NamedTuple):
4143 tfidf_vectors : str
4244 concept_aliases_list : str
4345
46+ @classmethod
47+ def from_directory (cls , directory : Union [str , Path ]) -> Self :
48+ if not os .path .isdir (directory ):
49+ raise NotADirectoryError
50+ return cls (
51+ ann_index = os .path .join (directory , "nmslib_index.bin" ),
52+ tfidf_vectorizer = os .path .join (directory , "tfidf_vectorizer.joblib" ),
53+ tfidf_vectors = os .path .join (directory , "tfidf_vectors_sparse.npz" ),
54+ concept_aliases_list = os .path .join (directory , "concept_aliases.json" ),
55+ )
56+
57+ def is_locally_cached (self ) -> bool :
58+ return all (os .path .isfile (x ) for x in self )
59+
60+ def get_concept_aliases (self ) -> List [str ]:
61+ with open (cached_path (self .concept_aliases_list )) as file :
62+ return json .load (file )
63+
64+ def get_tfidf_vectorizer (self ) -> TfidfVectorizer :
65+ return joblib .load (cached_path (self .tfidf_vectorizer ))
66+
67+ def get_ann_index (self , * , ef_search : int = 200 ) -> FloatIndex :
68+ return load_approximate_nearest_neighbours_index (self , ef_search = ef_search )
69+
70+ def load (
71+ self , * , ef_search : int = 200
72+ ) -> Tuple [List [str ], TfidfVectorizer , FloatIndex ]:
73+ concept_aliases = self .get_concept_aliases ()
74+ tfidf_vectorizer = self .get_tfidf_vectorizer ()
75+ ann_index = self .get_ann_index (ef_search = ef_search )
76+ return concept_aliases , tfidf_vectorizer , ann_index
77+
4478
4579UmlsLinkerPaths = LinkerPaths (
4680 ann_index = "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/nmslib_index.bin" , # noqa
@@ -220,14 +254,10 @@ def __init__(
220254
221255 linker_paths = DEFAULT_PATHS .get (name , UmlsLinkerPaths )
222256
223- self .ann_index = ann_index or load_approximate_nearest_neighbours_index (
224- linker_paths = linker_paths , ef_search = ef_search
225- )
226- self .vectorizer = tfidf_vectorizer or joblib .load (
227- cached_path (linker_paths .tfidf_vectorizer )
228- )
229- self .ann_concept_aliases_list = ann_concept_aliases_list or json .load (
230- open (cached_path (linker_paths .concept_aliases_list ))
257+ self .ann_index = ann_index or linker_paths .get_ann_index (ef_search = ef_search )
258+ self .vectorizer = tfidf_vectorizer or linker_paths .get_tfidf_vectorizer ()
259+ self .ann_concept_aliases_list = (
260+ ann_concept_aliases_list or linker_paths .get_concept_aliases ()
231261 )
232262
233263 self .kb = kb or DEFAULT_KNOWLEDGE_BASES [name ]()
@@ -364,7 +394,10 @@ def __call__(
364394
365395
366396def create_tfidf_ann_index (
367- out_path : Optional [str ], kb : Optional [KnowledgeBase ] = None
397+ out_path : Optional [str ],
398+ kb : Optional [KnowledgeBase ] = None ,
399+ * ,
400+ ef_search : int = 200 ,
368401) -> Tuple [List [str ], TfidfVectorizer , FloatIndex ]:
369402 """
370403 Build tfidf vectorizer and ann index.
@@ -378,6 +411,13 @@ def create_tfidf_ann_index(
378411 The kb items to generate the index and vectors for.
379412
380413 """
414+ if out_path is None :
415+ linker_paths = None
416+ else :
417+ linker_paths = LinkerPaths .from_directory (out_path )
418+ if linker_paths .is_locally_cached ():
419+ return linker_paths .load (ef_search = ef_search )
420+
381421 if not scipy_supports_sparse_float16 ():
382422 raise RuntimeError (
383423 "This function requires scipy<1.11, which only runs on Python<3.11."
@@ -419,8 +459,8 @@ def create_tfidf_ann_index(
419459 )
420460 start_time = datetime .datetime .now ()
421461 concept_alias_tfidfs = tfidf_vectorizer .fit_transform (concept_aliases )
422- if out_path is not None :
423- tfidf_vectorizer_path = os . path . join ( out_path , " tfidf_vectorizer.joblib" )
462+ if linker_paths is not None :
463+ tfidf_vectorizer_path = linker_paths . tfidf_vectorizer
424464 print (f"Saving tfidf vectorizer to { tfidf_vectorizer_path } " )
425465 joblib .dump (tfidf_vectorizer , tfidf_vectorizer_path )
426466 end_time = datetime .datetime .now ()
@@ -446,9 +486,9 @@ def create_tfidf_ann_index(
446486 concept_alias_tfidfs = concept_alias_tfidfs [empty_tfidfs_boolean_flags ]
447487 assert len (concept_aliases ) == numpy .size (concept_alias_tfidfs , 0 )
448488
449- if out_path is not None :
450- tfidf_vectors_path = os . path . join ( out_path , "tfidf_vectors_sparse.npz" )
451- concept_aliases_path = os . path . join ( out_path , "concept_aliases.json" )
489+ if linker_paths is not None :
490+ tfidf_vectors_path = linker_paths . tfidf_vectors
491+ concept_aliases_path = linker_paths . concept_aliases_list
452492 print (
453493 f"Saving list of concept ids and tfidfs vectors to { concept_aliases_path } and { tfidf_vectors_path } "
454494 )
@@ -467,8 +507,8 @@ def create_tfidf_ann_index(
467507 )
468508 ann_index .addDataPointBatch (concept_alias_tfidfs )
469509 ann_index .createIndex (index_params , print_progress = True )
470- if out_path is not None :
471- ann_index_path = os . path . join ( out_path , "nmslib_index.bin" )
510+ if linker_paths is not None :
511+ ann_index_path = linker_paths . ann_index
472512 ann_index .saveIndex (ann_index_path )
473513 end_time = datetime .datetime .now ()
474514 elapsed_time = end_time - start_time
0 commit comments