Skip to content

Commit b85a6e5

Browse files
author
Marcin Kardas
committed
Add abstract and references contexts
* add abstract context * add table contexts consisting of paragraphs referencing a given table * count evidences with repetitions * stop adding abbreviations as dataset evidences * fix passing paper context as a string * add confidence to non-filtered linking metrics to analyze results * add threshold_map to make it easier to choose stable regions
1 parent 081bd5c commit b85a6e5

File tree

5 files changed

+121
-44
lines changed

5 files changed

+121
-44
lines changed

sota_extractor2/helpers/explainers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,16 @@ def _get_sota_records(self, paper):
168168
records.index.rename("cell_ext_id", inplace=True)
169169
return records
170170

171-
def linking_metrics(self, experiment_name="unk", topk_metrics=False, filtered=True):
171+
def linking_metrics(self, experiment_name="unk", topk_metrics=False, filtered=True, confidence=0.0):
172172
paper_ids = list(self.le.proposals.keys())
173173

174174
proposals = pd.concat(self.le.proposals.values())
175175

176176
# if not topk_metrics:
177177
if filtered:
178178
proposals = proposals[~proposals.index.isin(self.fe.reason.index)]
179+
if confidence:
180+
proposals = proposals[proposals.confidence > confidence]
179181

180182
papers = {paper_id: self.paper_collection.get_by_id(paper_id) for paper_id in paper_ids}
181183
missing = [paper_id for paper_id, paper in papers.items() if paper is None]

sota_extractor2/helpers/optimize.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass, replace
33
from sota_extractor2.models.linking.metrics import CM
44
from matplotlib import pyplot as plt
5+
import matplotlib.tri as tri
56

67

78
def annotations(matrix, structure, r, c, type='model'):
@@ -259,6 +260,20 @@ def best(self, min_precision=0, min_recall=0, min_f1=0):
259260
self._best(results, "recall")
260261
self._best(results, "f1")
261262

263+
def threshold_map(self, metric):
264+
lin = np.linspace(0, 1, 64)
265+
266+
triang = tri.Triangulation(self.results.threshold1.values, self.results.threshold2.values)
267+
interpolator = tri.LinearTriInterpolator(triang, self.results[metric])
268+
Xi, Yi = np.meshgrid(lin, lin)
269+
zi = interpolator(Xi, Yi)
270+
plt.figure(figsize=(6, 6))
271+
img = plt.imshow(zi[::-1], extent=[0, 1, 0, 1])
272+
plt.colorbar(img)
273+
plt.xlabel("threshold1")
274+
plt.ylabel("threshold2")
275+
276+
262277
def optimize_filters(explainer, metrics_info):
263278
df = merge_gold_records(explainer)
264279
df = find_threshold_intervals(df, metrics_info, context="paper")

sota_extractor2/loggers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def on_before_linking(self, step, paper, tables):
159159
def on_after_linking(self, step, paper, tables, proposals):
160160
self.proposals[paper.paper_id] = proposals.copy(deep=True)
161161

162-
def on_before_taxonomy(self, step, ext_id, query, datasets, caption):
163-
self.queries[ext_id] = (query, datasets, caption)
162+
def on_before_taxonomy(self, step, ext_id, query, paper_context, abstract_context, table_context, caption):
163+
self.queries[ext_id] = (query, paper_context, abstract_context, table_context, caption)
164164

165165
def on_taxonomy_topk(self, step, ext_id, topk):
166166
paper_id, table_name, rc = ext_id.split('/')

sota_extractor2/models/linking/bm25_naive.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def convert_metric(raw_value, rng, complementary):
240240
'confidence', 'parsed', 'struct_model_type', 'struct_dataset']
241241

242242

243-
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets, topk=1):
243+
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking,
244+
paper_context, abstract_context, table_context, topk=1):
244245
# %%
245246
# Proposal generation
246247
def consume_cells(matrix):
@@ -289,7 +290,8 @@ def linked_proposals(proposals):
289290
if percentage:
290291
format += '%'
291292

292-
df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
293+
df = taxonomy_linking(prop.dataset, paper_context, abstract_context, table_context,
294+
desc, topk=topk, debug_info=prop)
293295
for _, row in df.iterrows():
294296
raw_value = prop.raw_value
295297
task = row['task']
@@ -335,9 +337,10 @@ def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=Non
335337
dataset_extractor=None, topk=1):
336338
# dataset_extractor=DatasetExtractor()):
337339
proposals = []
338-
datasets = dataset_extractor.from_paper(paper)
340+
paper_context, abstract_context = dataset_extractor.from_paper(paper)
341+
table_contexts = dataset_extractor.get_table_contexts(paper, annotated_tables)
339342
#print(f"Extracted datasets: {datasets}")
340-
for idx, table in enumerate(annotated_tables):
343+
for idx, (table, table_context) in enumerate(zip(annotated_tables, table_contexts)):
341344
matrix = np.array(table.matrix)
342345
structure = np.array(table.matrix_tags)
343346
tags = 'sota'
@@ -347,7 +350,9 @@ def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=Non
347350
if 'sota' in tags and 'no_sota_records' not in tags: # only parse tables that are marked as sota
348351
proposals.append(
349352
generate_proposals_for_table(
350-
table_ext_id, matrix, structure, desc, taxonomy_linking, datasets, topk=topk
353+
table_ext_id, matrix, structure, desc, taxonomy_linking,
354+
paper_context, abstract_context, table_context,
355+
topk=topk
351356
)
352357
)
353358
if len(proposals):

sota_extractor2/models/linking/context_search.py

Lines changed: 91 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sota_extractor2.pipeline_logger import pipeline_logger
1515

1616
from sota_extractor2.models.linking import manual_dicts
17+
from collections import Counter
1718

1819
def dummy_item(reason):
1920
return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))
@@ -64,10 +65,10 @@ def find_names(text, names_trie):
6465
profile = EvidenceFinder.single_letter_re.sub("x", profile)
6566
text = text.replace(" ", "")
6667
profile = profile.replace(" ", "")
67-
s = set()
68+
s = Counter()
6869
for (end, (l, word)) in names_trie.iter(text):
6970
if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
70-
s.add(word)
71+
s[word] += 1
7172
return s
7273

7374
def find_datasets(self, text):
@@ -105,30 +106,31 @@ def _init_structs(self, taxonomy):
105106

106107

107108
@njit
108-
def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
109+
def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb, max_repetitions):
109110
logprob = 0.0
110111
empty = typed.Dict.empty(types.unicode_type, types.float64)
111112
short_probs = reverse_probs.get(evidences_for, empty)
112-
for evidence in found_evidences:
113-
logprob += np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
113+
for evidence, count in found_evidences.items():
114+
logprob += min(count, max_repetitions) * np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
114115
return logprob
115116

116117

117118
# compute log-probabilities in a given context and add them to logprobs
118119
@njit
119120
def compute_logprobs(taxonomy, tasks, datasets, metrics,
120121
reverse_merged_p, reverse_metrics_p, reverse_task_p,
121-
dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs):
122+
dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs,
123+
max_repetitions):
122124
task_cache = typed.Dict.empty(types.unicode_type, types.float64)
123125
dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
124126
metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
125127
for i, (task, dataset, metric) in enumerate(taxonomy):
126128
if dataset not in dataset_cache:
127-
dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb)
129+
dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb, 1)
128130
if metric not in metric_cache:
129-
metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb)
131+
metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb, 1)
130132
if task not in task_cache:
131-
task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
133+
task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb, max_repetitions)
132134

133135
logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
134136
for i, task in enumerate(tasks):
@@ -149,7 +151,7 @@ def _to_typed_list(iterable):
149151

150152

151153
class ContextSearch:
152-
def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
154+
def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.1, 0.2, 0.2, 0.1), metric_noise=None, task_noise=None,
153155
ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
154156
merged_p = \
155157
get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.datasets.items()})[1]
@@ -169,7 +171,7 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
169171

170172
self.extract_acronyms = AcronymExtractor()
171173
self.context_noise = context_noise
172-
self.metrics_noise = metrics_noise if metrics_noise else context_noise
174+
self.metrics_noise = metric_noise if metric_noise else context_noise
173175
self.task_noise = task_noise if task_noise else context_noise
174176
self.ds_pb = ds_pb
175177
self.ms_pb = ms_pb
@@ -178,6 +180,7 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
178180
self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
179181
self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
180182
self.debug_gold_df = debug_gold_df
183+
self.max_repetitions = 1
181184

182185
def _numba_update_nested_dict(self, nested):
183186
d = typed.Dict()
@@ -188,32 +191,43 @@ def _numba_update_nested_dict(self, nested):
188191
return d
189192

190193
def _numba_extend_list(self, lst):
191-
l = typed.List.empty_list(types.unicode_type)
194+
l = typed.List.empty_list((types.unicode_type, types.int32))
192195
for x in lst:
193196
l.append(x)
194197
return l
195198

199+
def _numba_extend_dict(self, dct):
200+
d = typed.Dict.empty(types.unicode_type, types.int64)
201+
d.update(dct)
202+
return d
203+
196204
def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
197-
context = context or ""
198-
abbrvs = self.extract_acronyms(context)
199-
context = normalize_cell_ws(normalize_dataset_ws(context))
200-
dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
201-
mss = set(self.evidence_finder.find_metrics(context))
202-
tss = set(self.evidence_finder.find_tasks(context))
203-
dss -= mss
204-
dss -= tss
205-
dss = [normalize_cell(ds) for ds in dss]
206-
mss = [normalize_cell(ms) for ms in mss]
207-
tss = [normalize_cell(ts) for ts in tss]
205+
if isinstance(context, str) or context is None:
206+
context = context or ""
207+
#abbrvs = self.extract_acronyms(context)
208+
context = normalize_cell_ws(normalize_dataset_ws(context))
209+
#dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
210+
dss = self.evidence_finder.find_datasets(context)
211+
mss = self.evidence_finder.find_metrics(context)
212+
tss = self.evidence_finder.find_tasks(context)
213+
214+
dss -= mss
215+
dss -= tss
216+
else:
217+
tss, dss, mss = context
218+
219+
dss = {normalize_cell(ds): count for ds, count in dss.items()}
220+
mss = {normalize_cell(ms): count for ms, count in mss.items()}
221+
tss = {normalize_cell(ts): count for ts, count in tss.items()}
208222
###print("dss", dss)
209223
###print("mss", mss)
210-
dss = self._numba_extend_list(dss)
211-
mss = self._numba_extend_list(mss)
212-
tss = self._numba_extend_list(tss)
224+
dss = self._numba_extend_dict(dss)
225+
mss = self._numba_extend_dict(mss)
226+
tss = self._numba_extend_dict(tss)
213227
compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
214228
self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
215229
dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
216-
axes_logprobs)
230+
axes_logprobs, self.max_repetitions)
217231

218232
def match(self, contexts):
219233
assert len(contexts) == len(self.context_noise)
@@ -239,11 +253,16 @@ def match(self, contexts):
239253
zip(self.taxonomy.metrics, axes_probs[2])
240254
)
241255

242-
def __call__(self, query, datasets, caption, topk=1, debug_info=None):
256+
def __call__(self, query, paper_context, abstract_context, table_context, caption, topk=1, debug_info=None):
243257
cellstr = debug_info.cell.cell_ext_id
244-
pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query, datasets=datasets, caption=caption)
245-
datasets = " ".join(datasets)
246-
key = (datasets, caption, query, topk)
258+
pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query,
259+
paper_context=paper_context, abstract_context=abstract_context, table_context=table_context,
260+
caption=caption)
261+
262+
paper_hash = ";".join(",".join(s.elements()) for s in paper_context)
263+
abstract_hash = ";".join(",".join(s.elements()) for s in abstract_context)
264+
mentions_hash = ";".join(",".join(s.elements()) for s in table_context)
265+
key = (paper_hash, abstract_hash, mentions_hash, caption, query, topk)
247266
###print(f"[DEBUG] {cellstr}")
248267
###print("[DEBUG]", debug_info)
249268
###print("query:", query, caption)
@@ -261,7 +280,7 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
261280
###print("Taking result from cache")
262281
p = self.queries[key]
263282
else:
264-
dists = self.match((datasets, caption, query))
283+
dists = self.match((paper_context, abstract_context, table_context, caption, query))
265284

266285
all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
267286
top_results, top_results_t, top_results_d, top_results_m = all_top_results
@@ -279,7 +298,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
279298
# task=top_results_t[i][0],
280299
# dataset=top_results_d[i][0],
281300
# metric=top_results_m[i][0])
282-
# best_independent.update({"evidence": "", "confidence": top_results_t[i][1]})
301+
# best_independent.update({
302+
# "evidence": "",
303+
# "confidence": np.power(top_results_t[i][1] * top_results_d[i][1] * top_results_m[i][1], 1.0/3.0)
304+
# })
283305
# entries.append(best_independent)
284306
#entries = [best_independent] + entries
285307

@@ -314,18 +336,51 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
314336

315337

316338
# todo: compare regex approach (old) with find_datasets(.) (current)
339+
# todo: rename it
317340
class DatasetExtractor:
318341
def __init__(self, evidence_finder):
319342
self.evidence_finder = evidence_finder
320343
self.dataset_prefix_re = re.compile(r"[A-Z]|[a-z]+[A-Z]+|[0-9]")
321344
self.dataset_name_re = re.compile(r"\b(the)\b\s*(?P<name>((?!(the)\b)\w+\W+){1,10}?)(test|val(\.|idation)?|dev(\.|elopment)?|train(\.|ing)?\s+)?\bdata\s*set\b", re.IGNORECASE)
322345

346+
def find_references(self, text, references):
347+
refs = r"\bxxref-(" + "|".join([re.escape(ref) for ref in references]) + r")\b"
348+
return set(re.findall(refs, text))
349+
350+
def get_table_contexts(self, paper, tables):
351+
ref_tables = [table for table in tables if table.figure_id]
352+
refs = [table.figure_id.replace(".", "") for table in ref_tables]
353+
ref_contexts = {ref: [Counter(), Counter(), Counter()] for ref in refs}
354+
if hasattr(paper.text, "fragments"):
355+
for fragment in paper.text.fragments:
356+
found_refs = self.find_references(fragment.text, refs)
357+
if found_refs:
358+
ts, ds, ms = self(fragment.header + "\n" + fragment.text)
359+
for ref in found_refs:
360+
ref_contexts[ref][0] += ts
361+
ref_contexts[ref][1] += ds
362+
ref_contexts[ref][2] += ms
363+
table_contexts = [
364+
ref_contexts.get(
365+
table.figure_id.replace(".", ""),
366+
[Counter(), Counter(), Counter()]
367+
) if table.figure_id else [Counter(), Counter(), Counter()]
368+
for table in tables
369+
]
370+
return table_contexts
371+
323372
def from_paper(self, paper):
324-
text = paper.text.abstract
373+
abstract = paper.text.abstract
374+
text = ""
325375
if hasattr(paper.text, "fragments"):
326376
text += " ".join(f.text for f in paper.text.fragments)
327-
return self(text)
377+
return self(text), self(abstract)
328378

329379
def __call__(self, text):
330380
text = normalize_cell_ws(normalize_dataset_ws(text))
331-
return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
381+
ds = self.evidence_finder.find_datasets(text)
382+
ts = self.evidence_finder.find_tasks(text)
383+
ms = self.evidence_finder.find_metrics(text)
384+
ds -= ts
385+
ds -= ms
386+
return ts, ds, ms

0 commit comments

Comments
 (0)