Skip to content

Commit 3ee1b85

Browse files
refactor: pymorphy2 to spacy in Entity Linking and KBQA (#1618)
Co-authored-by: Fedor Ignatov <[email protected]>
1 parent 9ff98b6 commit 3ee1b85

File tree

5 files changed

+34
-40
lines changed

5 files changed

+34
-40
lines changed

deeppavlov/core/common/requirements_registry.json

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
],
66
"entity_linker": [
77
"{DEEPPAVLOV_PATH}/requirements/hdt.txt",
8-
"{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt"
8+
"{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt",
9+
"{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
10+
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
911
],
1012
"fasttext": [
1113
"{DEEPPAVLOV_PATH}/requirements/fasttext.txt"
@@ -58,6 +60,7 @@
5860
"{DEEPPAVLOV_PATH}/requirements/transformers.txt"
5961
],
6062
"ru_adj_to_noun": [
63+
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt",
6164
"{DEEPPAVLOV_PATH}/requirements/udapi.txt"
6265
],
6366
"russian_words_vocab": [
@@ -147,7 +150,9 @@
147150
"{DEEPPAVLOV_PATH}/requirements/transformers.txt"
148151
],
149152
"tree_to_sparql": [
150-
"{DEEPPAVLOV_PATH}/requirements/udapi.txt"
153+
"{DEEPPAVLOV_PATH}/requirements/udapi.txt",
154+
"{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
155+
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
151156
],
152157
"typos_custom_reader": [
153158
"{DEEPPAVLOV_PATH}/requirements/lxml.txt"

deeppavlov/models/entity_extraction/entity_linking.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@
1414

1515
import re
1616
import sqlite3
17+
from collections import defaultdict
1718
from logging import getLogger
1819
from typing import List, Dict, Tuple, Union, Any
19-
from collections import defaultdict
2020

21-
import pymorphy2
21+
import spacy
2222
from hdt import HDTDocument
2323
from nltk.corpus import stopwords
2424
from rapidfuzz import fuzz
2525

26+
from deeppavlov.core.commands.utils import expand_path
2627
from deeppavlov.core.common.registry import register
2728
from deeppavlov.core.models.component import Component
2829
from deeppavlov.core.models.serializable import Serializable
29-
from deeppavlov.core.commands.utils import expand_path
3030

3131
log = getLogger(__name__)
3232

@@ -75,7 +75,6 @@ def __init__(
7575
**kwargs:
7676
"""
7777
super().__init__(save_path=None, load_path=load_path)
78-
self.morph = pymorphy2.MorphAnalyzer()
7978
self.lemmatize = lemmatize
8079
self.entities_database_filename = entities_database_filename
8180
self.num_entities_for_bert_ranking = num_entities_for_bert_ranking
@@ -86,8 +85,10 @@ def __init__(
8685
self.lang = f"@{lang}"
8786
if self.lang == "@en":
8887
self.stopwords = set(stopwords.words("english"))
88+
self.nlp = spacy.load("en_core_web_sm")
8989
elif self.lang == "@ru":
9090
self.stopwords = set(stopwords.words("russian"))
91+
self.nlp = spacy.load("ru_core_news_sm")
9192
self.use_descriptions = use_descriptions
9293
self.use_connections = use_connections
9394
self.max_paragraph_len = max_paragraph_len
@@ -198,7 +199,7 @@ def link_entities(
198199
):
199200
cand_ent_scores = []
200201
if len(entity_substr) > 1:
201-
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
202+
entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
202203
cand_ent_init = self.find_exact_match(entity_substr, tag)
203204
if not cand_ent_init or entity_substr_split != entity_substr_split_lemm:
204205
cand_ent_init = self.find_fuzzy_match(entity_substr_split, tag)
@@ -297,28 +298,23 @@ def find_exact_match(self, entity_substr, tag):
297298
entity_substr_split = entity_substr_split[1:]
298299
entities_and_ids = self.find_title(entity_substr)
299300
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
300-
if self.lang == "@ru":
301-
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
302-
entity_substr_lemm = " ".join(entity_substr_split_lemm)
303-
if entity_substr_lemm != entity_substr:
304-
entities_and_ids = self.find_title(entity_substr_lemm)
305-
if entities_and_ids:
306-
cand_ent_init = self.process_cand_ent(
307-
cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
308-
)
301+
302+
entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
303+
entity_substr_lemm = " ".join(entity_substr_split_lemm)
304+
if entity_substr_lemm != entity_substr:
305+
entities_and_ids = self.find_title(entity_substr_lemm)
306+
if entities_and_ids:
307+
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag)
309308
return cand_ent_init
310309

311310
def find_fuzzy_match(self, entity_substr_split, tag):
312-
if self.lang == "@ru":
313-
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
314-
else:
315-
entity_substr_split_lemm = entity_substr_split
311+
entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
316312
cand_ent_init = defaultdict(set)
317313
for word in entity_substr_split:
318314
part_entities_and_ids = self.find_title(word)
319315
cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
320316
if self.lang == "@ru":
321-
word_lemm = self.morph.parse(word)[0].normal_form
317+
word_lemm = self.nlp(word)[0].lemma_
322318
if word != word_lemm:
323319
part_entities_and_ids = self.find_title(word_lemm)
324320
cand_ent_init = self.process_cand_ent(
@@ -329,11 +325,6 @@ def find_fuzzy_match(self, entity_substr_split, tag):
329325
)
330326
return cand_ent_init
331327

332-
def morph_parse(self, word):
333-
morph_parse_tok = self.morph.parse(word)[0]
334-
normal_form = morph_parse_tok.normal_form
335-
return normal_form
336-
337328
def calc_substr_score(self, cand_entity_title, entity_substr_split):
338329
label_tokens = cand_entity_title.split()
339330
cnt = 0.0

deeppavlov/models/kbqa/tree_to_sparql.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, List, Tuple, Dict, Union
2020

2121
import numpy as np
22-
import pymorphy2
22+
import spacy
2323
from navec import Navec
2424
from scipy.sparse import csr_matrix
2525
from slovnet import Syntax
@@ -66,11 +66,10 @@ def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, **kwargs)
6666
self.adj_set = set([word for word, freq in pos_freq_dict["a"]])
6767
self.nouns = [noun[0] for noun in self.nouns_with_freq]
6868
self.matrix = self.make_sparse_matrix(self.nouns).transpose()
69-
self.morph = pymorphy2.MorphAnalyzer()
69+
self.nlp = spacy.load("ru_core_news_sm")
7070

7171
def search(self, word: str):
72-
word = self.morph.parse(word)[0]
73-
word = word.normal_form
72+
word = self.nlp(word)[0].lemma_
7473
if word in self.adj_set:
7574
q_matrix = self.make_sparse_matrix([word])
7675
scores = q_matrix * self.matrix
@@ -190,6 +189,7 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
190189
self.begin_tokens = {"начинать", "начать"}
191190
self.end_tokens = {"завершить", "завершать", "закончить"}
192191
self.ranking_tokens = {"самый"}
192+
self.nlp = spacy.load("ru_core_news_sm")
193193
elif self.lang == "eng":
194194
self.q_pronouns = {"what", "who", "how", "when", "where", "which"}
195195
self.how_many = "how many"
@@ -199,12 +199,12 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
199199
self.begin_tokens = set()
200200
self.end_tokens = set()
201201
self.ranking_tokens = set()
202+
self.nlp = spacy.load("en_core_web_sm")
202203
else:
203204
raise ValueError(f"unsupported language {lang}")
204205
self.sparql_queries_filename = expand_path(sparql_queries_filename)
205206
self.template_queries = read_json(self.sparql_queries_filename)
206207
self.adj_to_noun = adj_to_noun
207-
self.morph = pymorphy2.MorphAnalyzer()
208208

209209
def __call__(self, syntax_tree_batch: List[str],
210210
positions_batch: List[List[List[int]]]) -> Tuple[
@@ -274,7 +274,7 @@ def __call__(self, syntax_tree_batch: List[str],
274274
self.root_entity = True
275275

276276
temporal_order = self.find_first_last(new_root)
277-
new_root_nf = self.morph.parse(new_root.form)[0].normal_form
277+
new_root_nf = self.nlp(new_root.form)[0].lemma_
278278
if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens:
279279
temporal_order = new_root_nf
280280
ranking_tokens = self.find_ranking_tokens(new_root)
@@ -288,7 +288,7 @@ def __call__(self, syntax_tree_batch: List[str],
288288
question = []
289289
for node in tree.descendants:
290290
if node.ord in ranking_tokens or node.form.lower() in self.q_pronouns:
291-
question.append(self.morph.parse(node.form)[0].normal_form)
291+
question.append(self.nlp(node.form)[0].lemma_)
292292
else:
293293
question.append(node.form)
294294
question = ' '.join(question)
@@ -496,9 +496,9 @@ def find_first_last(self, node: Node) -> str:
496496
for node in nodes:
497497
node_desc = defaultdict(set)
498498
for elem in node.children:
499-
parsed_elem = self.morph.parse(elem.form.lower())[0].inflect({"masc", "sing", "nomn"})
499+
parsed_elem = self.nlp(elem.form.lower())[0].lemma_
500500
if parsed_elem is not None:
501-
node_desc[elem.deprel].add(parsed_elem.word)
501+
node_desc[elem.deprel].add(parsed_elem)
502502
else:
503503
node_desc[elem.deprel].add(elem.form)
504504
if "amod" in node_desc.keys() and "nmod" in node_desc.keys() and \
@@ -511,7 +511,7 @@ def find_first_last(self, node: Node) -> str:
511511
def find_ranking_tokens(self, node: Node) -> list:
512512
ranking_tokens = []
513513
for elem in node.descendants:
514-
if self.morph.parse(elem.form)[0].normal_form in self.ranking_tokens:
514+
if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens:
515515
ranking_tokens.append(elem.ord)
516516
ranking_tokens.append(elem.parent.ord)
517517
return ranking_tokens

deeppavlov/models/kbqa/type_define.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import pickle
1616
from typing import List
1717

18-
import pymorphy2
1918
import spacy
2019
from nltk.corpus import stopwords
2120

@@ -43,7 +42,6 @@ def __init__(self, lang: str, types_filename: str, types_sets_filename: str,
4342
self.types_filename = str(expand_path(types_filename))
4443
self.types_sets_filename = str(expand_path(types_sets_filename))
4544
self.num_types_to_return = num_types_to_return
46-
self.morph = pymorphy2.MorphAnalyzer()
4745
if self.lang == "@en":
4846
self.stopwords = set(stopwords.words("english"))
4947
self.nlp = spacy.load("en_core_web_sm")
@@ -102,7 +100,7 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
102100
types_substr_tokens = types_substr.split()
103101
types_substr_tokens = [tok for tok in types_substr_tokens if tok not in self.stopwords]
104102
if self.lang == "@ru":
105-
types_substr_tokens = [self.morph.parse(tok)[0].normal_form for tok in types_substr_tokens]
103+
types_substr_tokens = [self.nlp(tok)[0].lemma_ for tok in types_substr_tokens]
106104
types_substr_tokens = set(types_substr_tokens)
107105
types_scores = []
108106
for entity in self.types_dict:

tests/test_quick_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@
256256
("kbqa/kbqa_cq_ru.json", "kbqa", ('IP',)):
257257
[
258258
("Кто такой Оксимирон?", ("российский рэп-исполнитель",)),
259-
("Чем питаются коалы?", ("Лист",)),
259+
("Кто написал «Евгений Онегин»?", ("Александр Сергеевич Пушкин",)),
260260
("абв", ("Not Found",))
261261
]
262262
},

0 commit comments

Comments (0)