Skip to content

Commit 2bc186b

Browse files
KBQA fixes (#1591)
1 parent d39818c commit 2bc186b

File tree

4 files changed

+26
-19
lines changed

4 files changed

+26
-19
lines changed

deeppavlov/models/entity_extraction/entity_linking.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.use_tags = use_tags
9595
self.full_paragraph = full_paragraph
9696
self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
97-
self.not_found_str = "not in wiki"
97+
self.not_found_str = "not_in_wiki"
9898

9999
self.load()
100100

@@ -277,27 +277,31 @@ def process_cand_ent(self, cand_ent_init, entities_and_ids, entity_substr_split,
277277
cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels))
278278
return cand_ent_init
279279

280+
def find_title(self, entity_substr):
281+
entities_and_ids = []
282+
try:
283+
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr))
284+
entities_and_ids = res.fetchall()
285+
except sqlite3.OperationalError as e:
286+
log.debug(f"error in searching an entity {e}")
287+
return entities_and_ids
288+
280289
def find_exact_match(self, entity_substr, tag):
281290
entity_substr_split = entity_substr.split()
282291
cand_ent_init = defaultdict(set)
283-
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr))
284-
entities_and_ids = res.fetchall()
292+
entities_and_ids = self.find_title(entity_substr)
285293
if entities_and_ids:
286294
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
287295
if entity_substr.startswith("the "):
288296
entity_substr = entity_substr.split("the ")[1]
289297
entity_substr_split = entity_substr_split[1:]
290-
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr))
291-
entities_and_ids = res.fetchall()
298+
entities_and_ids = self.find_title(entity_substr)
292299
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
293300
if self.lang == "@ru":
294301
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
295302
entity_substr_lemm = " ".join(entity_substr_split_lemm)
296303
if entity_substr_lemm != entity_substr:
297-
res = self.cur.execute(
298-
"SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr_lemm)
299-
)
300-
entities_and_ids = res.fetchall()
304+
entities_and_ids = self.find_title(entity_substr_lemm)
301305
if entities_and_ids:
302306
cand_ent_init = self.process_cand_ent(
303307
cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
@@ -311,14 +315,12 @@ def find_fuzzy_match(self, entity_substr_split, tag):
311315
entity_substr_split_lemm = entity_substr_split
312316
cand_ent_init = defaultdict(set)
313317
for word in entity_substr_split:
314-
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(word))
315-
part_entities_and_ids = res.fetchall()
318+
part_entities_and_ids = self.find_title(word)
316319
cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
317320
if self.lang == "@ru":
318321
word_lemm = self.morph.parse(word)[0].normal_form
319322
if word != word_lemm:
320-
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(word_lemm))
321-
part_entities_and_ids = res.fetchall()
323+
part_entities_and_ids = self.find_title(word_lemm)
322324
cand_ent_init = self.process_cand_ent(
323325
cand_ent_init,
324326
part_entities_and_ids,

deeppavlov/models/kbqa/query_generator.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,8 +165,11 @@ def query_parser(self, question: str, query_info: Dict[str, str],
165165
rel_combs = make_combs(rels, permut=False)
166166
entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')]
167167
log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
168-
selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
169-
selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
168+
selected_entity_ids, selected_type_ids = [], []
169+
if entity_ids:
170+
selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
171+
if type_ids:
172+
selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
170173
entity_combs = make_combs(selected_entity_ids, permut=True)
171174
type_combs = make_combs(selected_type_ids, permut=False)
172175
log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]},"

deeppavlov/models/kbqa/rel_ranking_infer.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,10 +148,11 @@ def __call__(self, questions_list: List[str],
148148
answer_ids = answers_with_scores[0][0]
149149
if self.return_all_possible_answers and isinstance(answer_ids, tuple):
150150
answer_ids_input = [(answer_id, question) for answer_id in answer_ids]
151-
answer_ids = [answer_id.split("/")[-1] for answer_id in answer_ids]
151+
answer_ids = list(map(lambda x: x.split("/")[-1] if str(x).startswith("http") else x, answer_ids))
152152
else:
153153
answer_ids_input = [(answer_ids, question)]
154-
answer_ids = answer_ids.split("/")[-1]
154+
if str(answer_ids).startswith("http:"):
155+
answer_ids = answer_ids.split("/")[-1]
155156
parser_info_list = ["find_label" for _ in answer_ids_input]
156157
answer_labels = self.wiki_parser(parser_info_list, answer_ids_input)
157158
log.debug(f"answer_labels {answer_labels}")

deeppavlov/models/kbqa/type_define.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,15 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
8585
break
8686
elif token.head.text == type_noun and token.dep_ == "prep":
8787
if len(list(token.children)) == 1 \
88-
and not any([list(token.children)[0] in entity_substr.lower()
88+
and not any([[tok.text for tok in token.children][0] in entity_substr.lower()
8989
for entity_substr in entity_substr_list]):
90-
types_substr += [token.text, list(token.children)[0]]
90+
types_substr += [token.text, [tok.text for tok in token.children][0]]
9191
elif any([word in question for word in self.pronouns]):
9292
for token in doc:
9393
if token.dep_ == "nsubj" and not any([token.text in entity_substr.lower()
9494
for entity_substr in entity_substr_list]):
9595
types_substr.append(token.text)
96+
9697
types_substr = [(token, token_pos_dict[token]) for token in types_substr]
9798
types_substr = sorted(types_substr, key=lambda x: x[1])
9899
types_substr = " ".join([elem[0] for elem in types_substr])

0 commit comments

Comments (0)