Skip to content

Commit 2e88594

Browse files
authored
Fix/index error when unhighlighting (#434)
* make sure nonhighlighted ents don't cause IndexError when unhighlighting * linting
1 parent 0fc4633 commit 2e88594

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

spacy_llm/tasks/entity_linker/task.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
105105
self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
106106
self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
107107
self._n_shards = None
108-
109108
return [
110109
EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
111110
for i, doc in enumerate(docs)
@@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc:
335334
for ent in doc.ents
336335
if ent.start - 1 > 0 and doc[ent.start - 1].text == "*"
337336
}
338-
highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"}
337+
highlight_end_idx = {
338+
ent.end
339+
for ent in doc.ents
340+
if ent.end < len(doc) and doc[ent.end].text == "*"
341+
}
339342
highlight_idx = highlight_start_idx | highlight_end_idx
340343

341344
# Compute entity indices with removed highlights.

spacy_llm/tests/tasks/test_entity_linker.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,38 @@ def test_ent_highlighting():
682682
EntityLinkerTask.highlight_ents_in_doc(doc).text
683683
== "Alice goes to *Boston* to see the *Boston Celtics* game."
684684
)
685+
686+
687+
@pytest.mark.parametrize(
688+
"text,ents,include_ents",
689+
[
690+
(
691+
"Alice goes to Boston to see the Boston Celtics game.",
692+
[
693+
{"start": 3, "end": 4, "label": "LOC"},
694+
{"start": 7, "end": 9, "label": "ORG"},
695+
],
696+
[True, True],
697+
),
698+
(
699+
"I went to see Boston in concert yesterday",
700+
[
701+
{"start": 4, "end": 5, "label": "GPE"},
702+
{"start": 7, "end": 8, "label": "DATE"},
703+
],
704+
[True, False],
705+
),
706+
],
707+
)
708+
def test_ent_unhighlighting(text, ents, include_ents):
709+
"""Tests unhighlighting of entities in text."""
710+
nlp = spacy.blank("en")
711+
doc = nlp.make_doc(text)
712+
doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])]
713+
685714
assert (
686715
EntityLinkerTask.unhighlight_ents_in_doc(
687-
EntityLinkerTask.highlight_ents_in_doc(doc)
716+
EntityLinkerTask.highlight_ents_in_doc(doc, include_ents)
688717
).text
689718
== doc.text
690719
== text

0 commit comments

Comments
 (0)