Skip to content

Commit c989173

Browse files
author
Marcin Kardas
committed
Small fixes
1 parent b1f690a commit c989173

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

sota_extractor2/models/linking/context_search.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ def match(self, contexts):
248248
axes_probs = [softmax(np.array(a)) for a in axes_context_logprobs]
249249
return (
250250
zip(keys, probs),
251-
zip(self.taxonomy.tasks, axes_probs[0]),
252-
zip(self.taxonomy.datasets, axes_probs[1]),
253-
zip(self.taxonomy.metrics, axes_probs[2])
251+
zip(self._taxonomy_tasks, axes_probs[0]),
252+
zip(self._taxonomy_datasets, axes_probs[1]),
253+
zip(self._taxonomy_metrics, axes_probs[2])
254254
)
255255

256256
def __call__(self, query, paper_context, abstract_context, table_context, caption, topk=1, debug_info=None):
@@ -259,9 +259,9 @@ def __call__(self, query, paper_context, abstract_context, table_context, captio
259259
paper_context=paper_context, abstract_context=abstract_context, table_context=table_context,
260260
caption=caption)
261261

262-
paper_hash = ";".join(",".join(s.elements()) for s in paper_context)
263-
abstract_hash = ";".join(",".join(s.elements()) for s in abstract_context)
264-
mentions_hash = ";".join(",".join(s.elements()) for s in table_context)
262+
paper_hash = ";".join(",".join(sorted(s.elements())) for s in paper_context)
263+
abstract_hash = ";".join(",".join(sorted(s.elements())) for s in abstract_context)
264+
mentions_hash = ";".join(",".join(sorted(s.elements())) for s in table_context)
265265
key = (paper_hash, abstract_hash, mentions_hash, caption, query, topk)
266266
###print(f"[DEBUG] {cellstr}")
267267
###print("[DEBUG]", debug_info)
@@ -282,7 +282,7 @@ def __call__(self, query, paper_context, abstract_context, table_context, captio
282282
else:
283283
dists = self.match((paper_context, abstract_context, table_context, caption, query))
284284

285-
all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
285+
all_top_results = [sorted(list(dist), key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
286286
top_results, top_results_t, top_results_d, top_results_m = all_top_results
287287

288288
entries = []
@@ -358,8 +358,10 @@ def find_references(self, text, references):
358358
return set(re.findall(refs, text))
359359

360360
def get_table_contexts(self, paper, tables):
361-
ref_tables = [table for table in tables if table.figure_id]
361+
ref_tables = [table for table in tables if table.figure_id and table.figure_id.replace(".", "")]
362362
refs = [table.figure_id.replace(".", "") for table in ref_tables]
363+
if not refs:
364+
return [[Counter(), Counter(), Counter()] for table in tables]
363365
ref_contexts = {ref: [Counter(), Counter(), Counter()] for ref in refs}
364366
if hasattr(paper.text, "fragments"):
365367
for fragment in paper.text.fragments:

0 commit comments

Comments
 (0)