
Commit 13d750e

Improve some nlp utils (#178)
* only consider one model per lang
* cache lang to models func
* add func to process one text into spacy doc
* tests: add unit test for new nlp util func
* refactor: use single text process func in places
* tests: update mocks to use new func
1 parent ffb6e63 commit 13d750e
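
For context on the refactor: the old pattern wrapped a single text in a list and pulled the first doc back out; the new helper does that in one call. A minimal before/after sketch (the helper names come from this commit; the sample text is illustrative):

    from colandr.lib.nlp.utils import process_text_into_doc, process_texts_into_docs

    text = "An example sentence."  # illustrative input, not from the commit

    # before: wrap the single text in a list, then pull out the first doc
    doc = next(iter(process_texts_into_docs([text], max_len=None)), None)

    # after: one call that returns a spacy Doc, or None if no model can be loaded
    doc = process_text_into_doc(text, max_len=None)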

7 files changed: +85 -26 lines changed


colandr/lib/extractors/locations.py

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,7 @@
 
 from spacy.tokens import Span
 
-from ..nlp.utils import process_texts_into_docs
+from ..nlp.utils import process_text_into_doc
 from .metadata import Metadata
 
 
@@ -50,8 +50,7 @@ def extract_locations(self, record_id: int, text: str) -> list[Metadata]:
         if not text or not text.strip():
             return []
 
-        processed_docs_iter = process_texts_into_docs([text], max_len=None)
-        doc = next(iter(processed_docs_iter), None)
+        doc = process_text_into_doc(text, max_len=None)
         if doc is None:
             return []
 

colandr/lib/extractors/review_model.py

Lines changed: 2 additions & 5 deletions
@@ -17,7 +17,7 @@
 from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
 from spacy.tokens import Doc, Span
 
-from ..nlp.utils import process_texts_into_docs
+from ..nlp.utils import process_text_into_doc, process_texts_into_docs
 from .metadata import Metadata
 
 
@@ -444,10 +444,7 @@ def _process_text(self, text_content: str) -> tuple[pd.DataFrame, list[dict]]:
             Tuple containing the feature DataFrame and original sentences list.
         """
         main_content, _ = self._split_references(text_content)
-        processed_docs_iter = process_texts_into_docs(
-            [main_content], max_len=None, exclude=("ner",)
-        )
-        doc = next(iter(processed_docs_iter), None)
+        doc = process_text_into_doc(main_content, max_len=None, exclude=("ner",))
         return self._extract_features_from_doc(doc)
 
     def _is_valid_sentence(self, sent: Optional[Span]) -> bool:

colandr/lib/nlp/utils.py

Lines changed: 41 additions & 5 deletions
@@ -45,18 +45,19 @@ def detect_languages(texts: Iterable[str]) -> list[t.Optional[str]]:
     ]
 
 
-def get_lang_to_models() -> dict[str, list[str]]:
+@functools.cache
+def get_lang_to_models() -> dict[str, str]:
     """Get a mapping of ISO language code to installed spacy language models."""
-    lang_to_models = collections.defaultdict(list)
+    lang_to_models = {}
     models = spacy.util.get_installed_models()
     for model in models:
         if "_" in model:
             lang, _ = model.split("_", 1)
-            lang_to_models[lang].append(model)
+            lang_to_models[lang] = model
         else:
             LOGGER.warning("found unexpected spacy model name: %s", model)
 
-    return dict(lang_to_models)
+    return lang_to_models
 
 
 @functools.lru_cache(maxsize=10)
@@ -89,6 +90,41 @@ def load_spacy_lang(name: str, **kwargs) -> SpacyLang:
     return spacy_lang
 
 
+def process_text_into_doc(
+    text: str,
+    *,
+    max_len: t.Optional[int] = 1000,
+    fallback_lang: t.Optional[str] = "en",
+    **kwargs,
+) -> t.Optional[SpacyDoc]:
+    """
+    Args:
+        text
+        max_len: Maximum number of chars (code points) in text to include
+            when identifying its language and processing into a spacy document.
+        fallback_lang: Fallback language used in place of low-confidence predictions.
+        **kwargs: Passed as-is into :func:`load_spacy_lang()`.
+    """
+    # clean up whitespace, to make it easier on lang detector
+    text = text.strip().replace("\n", " ")
+    # truncate texts, optionally
+    if max_len is not None:
+        text = text[:max_len]
+    # identify most probable language (w/ optional fallback) for text
+    lang = detect_language(text) or fallback_lang
+    lang_models = get_lang_to_models()
+    if lang in lang_models:
+        spacy_lang: SpacyLang = load_spacy_lang(lang_models[lang], **kwargs)
+        spacy_doc = spacy_lang(text)
+        return spacy_doc
+    else:
+        LOGGER.info(
+            "unable to load spacy model for text with lang='%s'; doc set to null ...",
+            lang,
+        )
+        return None
+
+
 def process_texts_into_docs(
     texts: Iterable[str],
     *,
@@ -120,7 +156,7 @@ def process_texts_into_docs(
     lang_models = get_lang_to_models()
     for lang, tl_grp in itertools.groupby(text_langs, key=itemgetter(1)):
         if lang in lang_models:
-            spacy_lang = load_spacy_lang(lang_models[lang][0], **kwargs)
+            spacy_lang = load_spacy_lang(lang_models[lang], **kwargs)
             spacy_docs = spacy_lang.pipe((text for text, _ in tl_grp), n_process=1)
             for spacy_doc in spacy_docs:
                 yield spacy_doc
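
A minimal usage sketch of the new helper defined above (sample text and pipeline excludes are illustrative; assumes an installed English spacy model such as en_core_web_sm):

    from colandr.lib.nlp import utils

    # cached via functools.cache, and now maps each lang to a single model name,
    # e.g. {"en": "en_core_web_sm"}; repeated calls reuse the same dict
    lang_models = utils.get_lang_to_models()

    doc = utils.process_text_into_doc(
        "A short example English sentence.",
        max_len=1000,
        fallback_lang="en",
        exclude=("parser", "ner"),  # **kwargs are passed through to load_spacy_lang()
    )
    if doc is not None:
        print(doc.lang_, len(doc))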

colandr/tasks.py

Lines changed: 4 additions & 5 deletions
@@ -295,13 +295,12 @@ def get_fulltext_text_content_vector(fulltext_id: int):
         )
         return
 
-    docs = nlp_utils.process_texts_into_docs(
-        [fulltext["text_content"]],
+    doc = nlp_utils.process_text_into_doc(
+        fulltext["text_content"],
         max_len=3000,
         fallback_lang=None,
         exclude=("parser", "ner"),
     )
-    doc = next(iter(docs))
     text_content_vector_rep = doc.vector.tolist() if doc is not None else None
     if text_content_vector_rep is None:
         LOGGER.warning(
@@ -310,12 +309,12 @@ def get_fulltext_text_content_vector(fulltext_id: int):
         return
 
     fulltext["text_content_vector_rep"] = text_content_vector_rep
-    stmt = (
+    update_stmt = (
         sa.update(models.Study)
         .where(models.Study.id == fulltext_id)
         .values(fulltext=fulltext)
    )
-    db.session.execute(stmt)
+    db.session.execute(update_stmt)
     db.session.commit()
 

tests/lib/extractors/test_locations.py

Lines changed: 3 additions & 3 deletions
@@ -48,8 +48,8 @@ def test_is_in_reference(self):
 
         assert extractor.is_in_reference(mock_ent) is False
 
-    @patch("colandr.lib.extractors.locations.process_texts_into_docs")
-    def test_extract_locations(self, mock_process_texts):
+    @patch("colandr.lib.extractors.locations.process_text_into_doc")
+    def test_extract_locations(self, mock_process_text):
         """Test extract_locations function."""
         extractor = LocationExtractor()
 
@@ -80,7 +80,7 @@ def test_extract_locations(self, mock_process_texts):
 
         mock_doc.ents = [mock_ent1, mock_ent2]
 
-        mock_process_texts.return_value = iter([mock_doc])
+        mock_process_text.return_value = mock_doc
 
         extractor.is_in_reference = MagicMock(return_value=False)
 

tests/lib/extractors/test_review_model.py

Lines changed: 4 additions & 4 deletions
@@ -120,8 +120,8 @@ def test_compare_and_train_skips_retrain(self):
         assert retrained is False
         mock_train.assert_not_called()
 
-    @patch("colandr.lib.extractors.review_model.process_texts_into_docs")
-    def test_extract_metadata(self, mock_process_texts):
+    @patch("colandr.lib.extractors.review_model.process_text_into_doc")
+    def test_extract_metadata(self, mock_process_text):
         """Test the full metadata extraction integration."""
         model = ReviewModel()
 
@@ -141,11 +141,11 @@ def test_extract_metadata(self, mock_process_texts):
         mock_sent = self._create_mock_sentence(sent_text, has_verb=True)
         mock_doc = MagicMock()
         mock_doc.sents = [mock_sent]
-        mock_process_texts.return_value = iter([mock_doc])
+        mock_process_text.return_value = mock_doc
 
         results = model.extract_metadata(123, "some input text", threshold=0.5)
 
-        mock_process_texts.assert_called_once()
+        mock_process_text.assert_called_once()
         assert len(results) == 1
         result = results[0]
         assert result.record == 123

tests/lib/nlp/test_utils.py

Lines changed: 29 additions & 1 deletion
@@ -78,6 +78,34 @@ def test_detect_languages(texts, exp_langs):
     assert obs_langs == exp_langs
 
 
+@pytest.mark.parametrize(
+    ["text", "max_len", "fallback_lang"],
+    [
+        (
+            "This is a short -- but not too short -- example English sentence.",
+            1000,
+            None,
+        ),
+        ("And this is another short example English sentence.", 100, "en"),
+        ("Esta es una frase corta de ejemplo en español.", None, None),
+    ],
+)
+def test_process_text_into_doc(text, max_len, fallback_lang, app):
+    doc = utils.process_text_into_doc(
+        text,
+        max_len=max_len,
+        fallback_lang=fallback_lang,
+        exclude=("parser", "ner"),
+    )
+    assert isinstance(doc, Doc) or doc is None
+    if doc is not None and doc.lang_ == "en":
+        spacy_lang = utils.load_spacy_lang(
+            utils.get_lang_to_models()["en"], exclude=("parser", "ner")
+        )
+        assert isinstance(spacy_lang, Language) and isinstance(doc, Doc)  # type guards
+        assert spacy_lang(text).to_bytes() == doc.to_bytes()
+
+
 @pytest.mark.parametrize(
     ["texts", "max_len", "fallback_lang"],
     [
@@ -124,7 +152,7 @@ def test_process_texts_into_docs(texts, max_len, fallback_lang, app):
     assert any(isinstance(doc, Doc) for doc in docs)
     # sanity-check vector value for first text only
     spacy_lang = utils.load_spacy_lang(
-        utils.get_lang_to_models()["en"][0], exclude=("parser", "ner")
+        utils.get_lang_to_models()["en"], exclude=("parser", "ner")
     )
     doc = docs[0]
     assert isinstance(spacy_lang, Language) and isinstance(doc, Doc)  # type guards
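
One behavior the Spanish case above exercises: with fallback_lang=None and no installed model for the detected language, the helper logs and returns None instead of raising. A small sketch (illustrative; in the test suite this runs under the app fixture):

    from colandr.lib.nlp import utils

    doc = utils.process_text_into_doc(
        "Esta es una frase corta de ejemplo en español.",  # same text as the test
        fallback_lang=None,
    )
    # None unless a Spanish model is installed; otherwise a Doc with lang_ == "es"
    assert doc is None or doc.lang_ == "es"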
