Skip to content

Commit d005f9c

Browse files
author
Marcin Kardas
committed
Add text evidences for taxonomy entries
* plus some fixes
1 parent 298a938 commit d005f9c

File tree

6 files changed

+97
-22
lines changed

6 files changed

+97
-22
lines changed

sota_extractor2/data/elastic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,20 @@ def display_fragment(f, cell_type="", display=True):
352352
if display:
353353
display_html(html)
354354
return html
355+
356+
357+
def get_evidences_for_taxonomy(paper_id, task, dataset, metric, value):
    """Collect highlighted text fragments from one paper that mention a
    taxonomy entry.

    Args:
        paper_id: id of the paper whose fragments are searched.
        task, dataset, metric, value: taxonomy entry fields; they are joined
            into a single full-text query against the fragment text.

    Returns:
        The highlight snippets of the top-5 matching fragments, joined into
        one newline-separated string (empty string when nothing matches).
    """
    evidence_query = Fragment.search().highlight(
        'text', pre_tags="<b>", post_tags="</b>", fragment_size=50)

    # Fix: coerce each term to str before joining — `value` is typically a
    # number, and str.join raises TypeError on non-string items.
    values = [task, dataset, metric, value]
    query = {
        "query": ' '.join(str(v) for v in values)
    }
    fragments = list(evidence_query
                     .filter('term', paper_id=paper_id)
                     .query('match', text=query)[:5]
                     )
    return '\n'.join(' '.join(f.meta['highlight']['text']) for f in fragments)

sota_extractor2/models/linking/bm25_naive.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from decimal import Decimal
2+
from decimal import Decimal, localcontext, InvalidOperation
33
from dataclasses import dataclass
44
import numpy as np
55
import pandas as pd
@@ -9,6 +9,7 @@
99
import spacy
1010
from scispacy.abbreviation import AbbreviationDetector
1111
from sota_extractor2.models.linking.format import extract_value
12+
from functools import total_ordering
1213

1314

1415
@dataclass()
@@ -58,6 +59,44 @@ def model_type(self):
5859
def __str__(self):
5960
return f"{self.model_name}: {self.raw_value} on {self.dataset}"
6061

62+
63+
@total_ordering
class MetricValue(Decimal):
    """A Decimal subclass that remembers the raw value and unit it came from.

    The Decimal magnitude is normalized to an absolute fraction: a value
    tagged with unit '%' is divided by 100 (MetricValue(50, '%') == 0.5),
    while an untagged value is stored as-is.
    """

    # Raw value exactly as parsed from the source (no % normalization).
    value: Decimal
    # '%' when the source explicitly marked a percentage; None when unknown.
    unit: str = None

    def __new__(cls, value, unit):
        # Decimal is immutable, so the normalized magnitude must be set here.
        # Fix: compare the unit with `==`, not `is` — identity comparison of
        # string literals only works by accident of CPython interning and
        # raises SyntaxWarning on modern interpreters.
        return super().__new__(cls, value / Decimal(100) if unit == '%' else value)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def to_absolute(self):
        """Return the normalized magnitude as a plain Decimal."""
        return Decimal(self)

    # unit = None means that no unit was specified, so we have to guess the unit.
    # if there's a value "21" in a table's cell, then we guess if it's 21 or 0.21 (i.e., 21%)
    # based on the target metric properties.
    def to_percentage(self):
        """Return the raw value scaled into percentage space when it looks
        like a bare fraction (0 < v < 1 with no explicit unit)."""
        if self.unit is None and 0 < self.value < 1:
            return self.value * 100
        return self.value

    def complement(self):
        """Return the complement (1 - v, or 100 - v for values that look
        like bare percentages), e.g. to convert accuracy <-> error."""
        if self.unit is None and 1 < self.value < 100:
            value = 100 - self.value
        else:
            value = 1 - self.value
        return MetricValue(value, self.unit)

    def __repr__(self):
        return f"MetricValue({self.value}, {repr(self.unit)})"

    def __str__(self):
        return str(self.value)
98+
99+
61100
def mkquery_ngrams(query):
62101
return {
63102
"query": {
@@ -164,7 +203,9 @@ def handle_pm(value):
164203
for match in float_pm_re.findall(value):
165204
if not match[0]:
166205
try:
167-
yield Decimal(whitespace_re.sub("", match[1])) / (100 if match[-1] else 1)
206+
percent = bool(match[-1])
207+
value = Decimal(whitespace_re.sub("", match[1])) / (100 if percent else 1)
208+
yield MetricValue(value, "%" if percent else None)
168209
except:
169210
pass
170211
# %%
@@ -217,26 +258,30 @@ def annotations(r, c, type='model'):
217258
def linked_proposals(proposals):
218259
for prop in proposals:
219260
# heuristic to handle accuracy vs error
220-
first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
221261
format = "{x}"
222-
# if first_num > 1:
223-
# first_num /= 100
224-
# format = "{x/100}"
225-
if 0 < first_num < 1 and '%' not in prop.raw_value:
226-
first_num *= 100
227-
format = "{100*x}"
228-
if '%' in prop.raw_value:
262+
263+
percentage = '%' in prop.raw_value
264+
if percentage:
229265
format += '%'
230266

231267
df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
232268
for _, row in df.iterrows():
233269
raw_value = prop.raw_value
234-
parsed = extract_value(raw_value, format)
235270
metric = row['metric']
236-
if metric != row['true_metric']:
237-
metric = row['true_metric']
238-
parsed = 1 - parsed if 0 < parsed < 1 else 100 - parsed
239-
parsed = float(parsed)
271+
272+
with localcontext() as ctx:
273+
ctx.traps[InvalidOperation] = 0
274+
parsed = extract_value(raw_value, format)
275+
parsed = MetricValue(parsed, '%' if percentage else None)
276+
277+
if metric != row['true_metric']:
278+
metric = row['true_metric']
279+
parsed = parsed.complement()
280+
281+
if set(metric.lower().split()) & {"error", "accuracy", "bleu", "f1", "precision", "recall"}:
282+
parsed = float(parsed.to_percentage() / 100)
283+
else:
284+
parsed = float(parsed.to_absolute())
240285

241286
linked = {
242287
'dataset': row['dataset'],

sota_extractor2/models/linking/format.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import re
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation

# Matches (and captures) an optionally signed float: either digits with
# optional thousands separators ("12,345.6") or a leading-dot fraction
# (".5"), with optional scientific-notation exponent. The grouped form
# allows 1-3 leading digits before the ",ddd" thousands groups.
float_value_re = re.compile(r"([+-]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
# Same pattern but fully non-capturing, for embedding inside larger regexes.
float_value_nc = re.compile(r"(?:[+-]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
# Contents of a {...} placeholder in a format template.
par_re = re.compile(r"\{([^\}]*)\}")
# One or more backslash-escaped whitespace characters.
escaped_whitespace_re = re.compile(r"(\\\s)+")
88

sota_extractor2/models/linking/taxonomy.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class Taxonomy:
88
def __init__(self, taxonomy, metrics_info):
99
self.taxonomy = self._get_taxonomy(taxonomy)
10-
self.metrics_info = self._read_metrics_info(metrics_info)
10+
self.metrics_info, self.metrics_range = self._read_metrics_info(metrics_info)
1111
self.tasks = self._get_axis('task')
1212
self.datasets = self._get_axis('dataset')
1313
self.metrics = self._get_axis('metric')
@@ -52,9 +52,19 @@ def _get_axis(self, axis):
5252
def _read_metrics_info(self, path):
5353
records = self._read_json(path)
5454
metrics_info = {}
55+
metrics_range = {}
56+
mr = {}
5557
for r in records:
5658
task, dataset, metric = r['task'], r['dataset'], r['metric']
59+
key = (task, dataset, metric)
5760
d = 1 if r['higher_is_better'] else -1
58-
metrics_info[(task, dataset, metric)] = d
61+
rng = r['range']
62+
metrics_info[key] = d
5963
metrics_info[metric] = metrics_info.get(metric, 0) + d
60-
return metrics_info
64+
metrics_range[key] = rng
65+
s = mr.get(metric, {})
66+
s[rng] = s.get(rng, 0) + 1
67+
mr[metric] = s
68+
for metric in mr:
69+
metrics_range[metric] = sorted(mr[metric].items(), key=lambda x: x[1])[-1]
70+
return metrics_info, metrics_range

sota_extractor2/models/linking/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def normalize_dataset_ws(name):
3939

4040
def normalize_dataset(name):
    """Normalize a dataset name for matching: apply remove_references,
    rewrite years via year_2k_re, strip hyphens, collapse whitespace, and
    lowercase the ASCII-transliterated result."""
    cleaned = remove_references(name)
    # Order matters: year canonicalization runs before hyphen removal.
    substitutions = ((year_2k_re, r"\1"), (hyphens_re, ""), (ws_re, " "))
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)
    return unidecode(cleaned.strip().lower())
4646

@@ -51,4 +51,4 @@ def normalize_cell(s):
5151
def normalize_cell_ws(s):
    """Keep only alphanumeric and whitespace characters, transliterated to ASCII."""
    kept = [ch for ch in s if ch.isalnum() or ch.isspace()]
    return unidecode("".join(kept))
5353

54-
# end of cleaning & normalization
54+
# end of cleaning & normalization

sota_extractor2/pipeline_logger.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ class PipelineLogger:
55
def __init__(self):
    # Registered observers start empty; see register()/unregister-style
    # helpers elsewhere in the class for how entries are added.
    self.observers = []
77

8+
def reset(self):
    # Drop every registered observer. Rebinds a fresh list rather than
    # clearing in place, so external aliases of the old list are unaffected.
    self.observers = []
10+
811
def register(self, pattern, observer):
912
if isinstance(pattern, str):
1013
pattern = re.compile(pattern)

0 commit comments

Comments
 (0)