Skip to content

Commit 81a0cf3

Browse files
committed
Refactor linker
1 parent 678e88c commit 81a0cf3

File tree

2 files changed

+131
-27
lines changed

2 files changed

+131
-27
lines changed

scispacy/linking_utils.py

Lines changed: 106 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,65 @@
1-
from typing import List, Dict, NamedTuple, Optional, Set
1+
"""
2+
This submodule contains a data structure for storing lexical
3+
indexes over various biomedical vocabularies.
4+
5+
There are several built-in vocabularies, which can be imported
6+
and instantiated like in:
7+
8+
.. code-block:: python
9+
10+
from scispacy.linking_utils import UmlsKnowledgeBase
11+
12+
kb = UmlsKnowledgeBase()
13+
14+
In general, new :class:`KnowledgeBase` objects can be constructed
15+
from a list of :class:`Entity` objects, or a path to a JSON or JSONL
16+
file containing dictionaries shaped the same way:
17+
18+
.. code-block:: python
19+
20+
from scispacy.linking_utils import KnowledgeBase
21+
22+
# UMLS
23+
kb = KnowledgeBase(
24+
"https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/"
25+
"data/kbs/2023-04-23/umls_mesh_2022.jsonl"
26+
)
27+
28+
"""
29+
230
import json
331
from collections import defaultdict
32+
from contextlib import contextmanager
33+
from pathlib import Path
34+
from typing import (
35+
List,
36+
Dict,
37+
NamedTuple,
38+
Optional,
39+
Set,
40+
Union,
41+
Iterable,
42+
Tuple,
43+
DefaultDict,
44+
Generator,
45+
)
446

547
from scispacy.file_cache import cached_path
648
from scispacy.umls_semantic_type_tree import (
749
UmlsSemanticTypeTree,
850
construct_umls_tree_from_tsv,
951
)
1052

53+
__all__ = [
54+
"Entity",
55+
"KnowledgeBase",
56+
"UmlsKnowledgeBase",
57+
"Mesh",
58+
"GeneOntology",
59+
"HumanPhenotypeOntology",
60+
"RxNorm",
61+
]
62+
1163

1264
class Entity(NamedTuple):
1365
concept_id: str
@@ -38,6 +90,53 @@ def __repr__(self):
3890
DEFAULT_UMLS_TYPES_PATH = "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/umls_semantic_type_tree.tsv"
3991

4092

93+
@contextmanager
94+
def _iter_entities(
95+
path_or_entities: Union[str, Path, Iterable[Entity]],
96+
) -> Generator[Iterable[Entity], None, None]:
97+
"""Iterate through entities from a JSON file, JSONL file, or pass through an existing iterable."""
98+
if isinstance(path_or_entities, (str, Path)):
99+
# normalize paths
100+
path_or_entities = cached_path(path_or_entities)
101+
102+
# do the following inside a context manager to
103+
# make sure the file gets closed properly
104+
with open(path_or_entities) as file:
105+
if path_or_entities.endswith("jsonl"):
106+
yield (Entity(**json.loads(line)) for line in file)
107+
else:
108+
yield (Entity(**record) for record in json.load(file))
109+
else:
110+
yield path_or_entities
111+
112+
113+
def _index_entities(
114+
entities: Iterable[Entity],
115+
) -> Tuple[Dict[str, Entity], Dict[str, Set[str]]]:
116+
"""Create indexes over entities for use in a :class:`KnowledgeBase`.
117+
118+
Parameters
119+
----------
120+
entities :
121+
An iterable (e.g., a list) of entity objects
122+
123+
Returns
124+
-------
125+
A pair of indexes for:
126+
127+
1. A mapping from local unique identifiers (e.g., CUIs for UMLS) to entity objects
128+
2. A mapping from aliases (e.g., canonical names, aliases) to local unique identifiers
129+
"""
130+
cui_to_entity: Dict[str, Entity] = {}
131+
alias_to_cuis: DefaultDict[str, Set[str]] = defaultdict(set)
132+
for entity in entities:
133+
alias_to_cuis[entity.canonical_name].add(entity.concept_id)
134+
for alias in entity.aliases:
135+
alias_to_cuis[alias].add(entity.concept_id)
136+
cui_to_entity[entity.concept_id] = entity
137+
return cui_to_entity, dict(alias_to_cuis)
138+
139+
41140
class KnowledgeBase:
42141
"""
43142
A class representing two commonly needed views of a Knowledge Base:
@@ -50,31 +149,20 @@ class KnowledgeBase:
50149
The file path to the json/jsonl representation of the KB to load.
51150
"""
52151

152+
cui_to_entity: Dict[str, Entity]
153+
alias_to_cuis: Dict[str, Set[str]]
154+
53155
def __init__(
54156
self,
55-
file_path: Optional[str] = None,
157+
file_path: Union[None, str, Path, Iterable[Entity]] = None,
56158
):
57159
if file_path is None:
58160
raise ValueError(
59161
"Do not use the default arguments to KnowledgeBase. "
60162
"Instead, use a subclass (e.g UmlsKnowledgeBase) or pass a path to a kb."
61163
)
62-
if file_path.endswith("jsonl"):
63-
raw = (json.loads(line) for line in open(cached_path(file_path)))
64-
else:
65-
raw = json.load(open(cached_path(file_path)))
66-
67-
alias_to_cuis: Dict[str, Set[str]] = defaultdict(set)
68-
self.cui_to_entity: Dict[str, Entity] = {}
69-
70-
for concept in raw:
71-
unique_aliases = set(concept["aliases"])
72-
unique_aliases.add(concept["canonical_name"])
73-
for alias in unique_aliases:
74-
alias_to_cuis[alias].add(concept["concept_id"])
75-
self.cui_to_entity[concept["concept_id"]] = Entity(**concept)
76-
77-
self.alias_to_cuis: Dict[str, Set[str]] = {**alias_to_cuis}
164+
with _iter_entities(file_path) as entities:
165+
self.cui_to_entity, self.alias_to_cuis = _index_entities(entities)
78166

79167

80168
class UmlsKnowledgeBase(KnowledgeBase):

tests/test_linking.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from scispacy.candidate_generation import CandidateGenerator, create_tfidf_ann_index
77
from scispacy.linking import EntityLinker
8+
from scispacy.linking_utils import Entity, KnowledgeBase
89
from scispacy.umls_utils import UmlsKnowledgeBase
910
from scispacy.abbreviation import AbbreviationDetector
1011
from scispacy.util import scipy_supports_sparse_float16
@@ -27,32 +28,47 @@ def setUp(self):
2728
self.linker = EntityLinker(candidate_generator=candidate_generator, filter_for_definitions=False)
2829

2930
def test_naive_entity_linking(self):
31+
self._test_linker(self.linker)
32+
33+
def test_custom_loading(self):
34+
entities = [
35+
Entity(concept_id="C0000039", canonical_name="dipalmitoylphosphatidylcholine", types=["T109", "T121"])
36+
]
37+
kb = KnowledgeBase(entities)
38+
with tempfile.TemporaryDirectory() as dir_name:
39+
concept_aliases, tfidf_vectorizer, ann_index = create_tfidf_ann_index(dir_name, kb)
40+
candidate_generator = CandidateGenerator(ann_index, tfidf_vectorizer, concept_aliases, kb)
41+
linker = EntityLinker(candidate_generator=candidate_generator, filter_for_definitions=False)
42+
43+
self._test_linker(linker)
44+
45+
def _test_linker(self, linker: EntityLinker) -> None:
3046
text = "There was a lot of Dipalmitoylphosphatidylcholine."
3147
doc = self.nlp(text)
3248

3349
# Check that the linker returns nothing if we set the filter_for_definitions flag
3450
# and set the threshold very high for entities without definitions.
35-
self.linker.filter_for_definitions = True
36-
self.linker.no_definition_threshold = 3.0
37-
doc = self.linker(doc)
51+
linker.filter_for_definitions = True
52+
linker.no_definition_threshold = 3.0
53+
doc = linker(doc)
3854
assert doc.ents[0]._.kb_ents == []
3955

4056
# Check that the linker returns only high confidence entities if we
4157
# set the threshold to something more reasonable.
42-
self.linker.no_definition_threshold = 0.95
43-
doc = self.linker(doc)
58+
linker.no_definition_threshold = 0.95
59+
doc = linker(doc)
4460
assert doc.ents[0]._.kb_ents == [("C0000039", 1.0)]
4561

46-
self.linker.filter_for_definitions = False
47-
self.linker.threshold = 0.45
48-
doc = self.linker(doc)
62+
linker.filter_for_definitions = False
63+
linker.threshold = 0.45
64+
doc = linker(doc)
4965
# Without the filter_for_definitions filter, we get 2 entities for
5066
# the first mention.
5167
assert len(doc.ents[0]._.kb_ents) == 2
5268

5369
id_with_score = doc.ents[0]._.kb_ents[0]
5470
assert id_with_score == ("C0000039", 1.0)
55-
umls_entity = self.linker.kb.cui_to_entity[id_with_score[0]]
71+
umls_entity = linker.kb.cui_to_entity[id_with_score[0]]
5672
assert umls_entity.concept_id == "C0000039"
5773
assert umls_entity.types == ["T109", "T121"]
5874

0 commit comments

Comments
 (0)