Skip to content

Commit ac2a5b5

Browse files
committed
Add separate noise weights for metrics
* add separate noise weights for metrics
* add noise dataset and metric probabilities
* make format parsing less sensitive to whitespace
* show non-parsed gold sota records
1 parent 437a68e commit ac2a5b5

File tree

9 files changed

+63
-32
lines changed

9 files changed

+63
-32
lines changed

sota_extractor2/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,5 @@
2828

2929
linking_models = datasets / "linking" / "models"
3030
linking_data = datasets / "linking" / "data"
31+
32+
autodict = linking_data / "autodict"

sota_extractor2/loggers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(self, pipeline_logger):
151151
pipeline_logger.register("linking::linked", self.on_after_linking)
152152
self.proposals = {}
153153
self.topk = {}
154+
self.queries = {}
154155

155156
def on_before_linking(self, step, paper, tables):
156157
pass
@@ -159,7 +160,7 @@ def on_after_linking(self, step, paper, tables, proposals):
159160
self.proposals[paper.paper_id] = proposals.copy(deep=True)
160161

161162
def on_before_taxonomy(self, step, ext_id, query, datasets, caption):
162-
pass
163+
self.queries[ext_id] = (query, datasets, caption)
163164

164165
def on_taxonomy_topk(self, step, ext_id, topk):
165166
paper_id, table_name, rc = ext_id.split('/')

sota_extractor2/models/linking/acronym_extractor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import spacy
22
from scispacy.abbreviation import AbbreviationDetector
3-
from .utils import normalize_cell, normalize_dataset_ws
3+
from .utils import normalize_cell, normalize_dataset
44

55
class AcronymExtractor:
66
def __init__(self):
@@ -14,7 +14,7 @@ def __call__(self, text):
1414
abbrvs = {}
1515
for abrv in doc._.abbreviations:
1616
# abbrvs.setdefault(normalize_cell(str(abrv)), Counter())[str(abrv._.long_form)] += 1
17-
norm = normalize_cell(normalize_dataset_ws(str(abrv)))
17+
norm = normalize_cell(normalize_dataset(str(abrv)))
1818
if norm != '':
19-
abbrvs[norm] = normalize_cell(normalize_dataset_ws(str(abrv._.long_form)))
19+
abbrvs[norm] = normalize_cell(normalize_dataset(str(abrv._.long_form)))
2020
return abbrvs

sota_extractor2/models/linking/context_search.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sota_extractor2.models.linking.acronym_extractor import AcronymExtractor
55
from sota_extractor2.models.linking.probs import get_probs, reverse_probs
6-
from sota_extractor2.models.linking.utils import normalize_dataset_ws, normalize_cell, normalize_cell_ws
6+
from sota_extractor2.models.linking.utils import normalize_dataset, normalize_cell, normalize_cell_ws
77
from scipy.special import softmax
88
import re
99
import pandas as pd
@@ -108,16 +108,16 @@
108108
'Rain100L': ['rain100l'],
109109
'Rain12': ['rain12'],
110110
'Rain800': ['rain800'],
111-
'Rain1400': ['rain1400'],
112-
'Real Rain': ['real rain'],
113-
'Rain in Surveillance': ['ris'],
114-
'Rain in Driving': ['rid'],
111+
'Rain1400': ['rain1400'],
112+
'Real Rain': ['real rain'],
113+
'Rain in Surveillance': ['ris'],
114+
'Rain in Driving': ['rid'],
115115
'DID-MDN': ['did-mdn'],
116116
'SOTS': ['sots'],
117117
'Test 1': ['test 1'],
118118
'RainSynLight25': ['rainsynlight25'],
119-
'RainSynComplex25': ['rainsyncomplex25'],
120-
'NTURain': ['nturain'],
119+
'RainSynComplex25': ['rainsyncomplex25'],
120+
'NTURain': ['nturain'],
121121
'RainSynAll100': ['rainsynall100'],
122122
'SPA-DATA': ['spa-data'],
123123
'LasVR': ['lasvar'],
@@ -143,8 +143,8 @@
143143
# return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
144144

145145
#all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
146-
all_datasets = set(y for x in datasets.values() for y in x)
147-
all_metrics = set(y for x in metrics.values() for y in x)
146+
all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in datasets.values() for y in x)
147+
all_metrics = set(normalize_cell_ws(y) for x in metrics.values() for y in x)
148148
#all_metrics = set(metrics_p.keys())
149149

150150
# all_datasets_re = {x:name_to_re(x) for x in all_datasets}
@@ -201,7 +201,7 @@ def dummy_item(reason):
201201

202202

203203
@njit
204-
def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
204+
def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, ms_noise, ds_pb, ms_pb, logprobs):
205205
empty = typed.Dict.empty(types.unicode_type, types.float64)
206206
for i, (task, dataset, metric) in enumerate(taxonomy):
207207
logprob = 0.0
@@ -213,19 +213,19 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, no
213213
# ds = long_form
214214
# break
215215
# if merged_p[ds].get('NOMATCH', 0.0) < 0.5:
216-
logprob += np.log(noise * 0.001 + (1 - noise) * short_probs.get(ds, 0.0))
216+
logprob += np.log(noise * ds_pb + (1 - noise) * short_probs.get(ds, 0.0))
217217
for ms in mss:
218-
logprob += np.log(noise * 0.01 + (1 - noise) * met_probs.get(ms, 0.0))
218+
logprob += np.log(ms_noise * ms_pb + (1 - ms_noise) * met_probs.get(ms, 0.0))
219219
logprobs[i] += logprob
220220
#logprobs[(dataset, metric)] = logprob
221221

222222

223223
class ContextSearch:
224-
def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
224+
def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, ds_pb=0.001, ms_pb=0.01, debug_gold_df=None):
225225
merged_p = \
226-
get_probs({k: Counter([normalize_cell(normalize_dataset_ws(x)) for x in v]) for k, v in datasets.items()})[1]
226+
get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
227227
metrics_p = \
228-
get_probs({k: Counter([normalize_cell(normalize_dataset_ws(x)) for x in v]) for k, v in metrics.items()})[1]
228+
get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
229229

230230

231231
self.queries = {}
@@ -235,6 +235,9 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
235235
self._taxonomy.append(t)
236236
self.extract_acronyms = AcronymExtractor()
237237
self.context_noise = context_noise
238+
self.metrics_noise = metrics_noise if metrics_noise else context_noise
239+
self.ds_pb = ds_pb
240+
self.ms_pb = ms_pb
238241
self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
239242
self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
240243
self.debug_gold_df = debug_gold_df
@@ -253,10 +256,10 @@ def _numba_extend_list(self, lst):
253256
l.append(x)
254257
return l
255258

256-
def compute_context_logprobs(self, context, noise, logprobs):
259+
def compute_context_logprobs(self, context, noise, ms_noise, logprobs):
257260
context = context or ""
258261
abbrvs = self.extract_acronyms(context)
259-
context = normalize_cell_ws(normalize_dataset_ws(context))
262+
context = normalize_cell_ws(normalize_dataset(context))
260263
dss = set(find_datasets(context)) | set(abbrvs.keys())
261264
mss = set(find_metrics(context))
262265
dss -= mss
@@ -266,15 +269,16 @@ def compute_context_logprobs(self, context, noise, logprobs):
266269
###print("mss", mss)
267270
dss = self._numba_extend_list(dss)
268271
mss = self._numba_extend_list(mss)
269-
compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, dss, mss, noise, logprobs)
272+
compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p,
273+
dss, mss, noise, ms_noise, self.ds_pb, self.ms_pb, logprobs)
270274

271275
def match(self, contexts):
272276
assert len(contexts) == len(self.context_noise)
273277
n = len(self._taxonomy)
274278
context_logprobs = np.zeros(n)
275279

276-
for context, noise in zip(contexts, self.context_noise):
277-
self.compute_context_logprobs(context, noise, context_logprobs)
280+
for context, noise, ms_noise in zip(contexts, self.context_noise, self.metrics_noise):
281+
self.compute_context_logprobs(context, noise, ms_noise, context_logprobs)
278282
keys = self.taxonomy.taxonomy
279283
logprobs = context_logprobs
280284
#keys, logprobs = zip(*context_logprobs.items())
@@ -293,7 +297,7 @@ def __call__(self, query, datasets, caption, debug_info=None):
293297
# print(self.queries[key])
294298
# for context in key:
295299
# abbrvs = self.extract_acronyms(context)
296-
# context = normalize_cell_ws(normalize_dataset_ws(context))
300+
# context = normalize_cell_ws(normalize_dataset(context))
297301
# dss = set(find_datasets(context)) | set(abbrvs.keys())
298302
# mss = set(find_metrics(context))
299303
# dss -= mss
@@ -353,4 +357,4 @@ def from_paper(self, paper):
353357
return self(text)
354358

355359
def __call__(self, text):
356-
return find_datasets(normalize_cell_ws(normalize_dataset_ws(text)))
360+
return find_datasets(normalize_cell_ws(normalize_dataset(text)))

sota_extractor2/models/linking/execution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pandas as pd
22
from django.db import connection
3+
from IPython.core.display import display
34

45
from sota_extractor2.models.linking.metrics import Metrics
56
from sota_extractor2.models.linking.format import extract_value
@@ -36,6 +37,13 @@ def fetch_gold_sota_records():
3637
gold_sota_records["parsed"] = gold_sota_records[["raw_value", "format"]].apply(
3738
lambda row: float(extract_value(row.raw_value, row.format)), axis=1)
3839

40+
unparsed = gold_sota_records[gold_sota_records["parsed"] != gold_sota_records["parsed"]]
41+
if len(unparsed):
42+
print("Found unparsed values")
43+
display(unparsed.style.format({'cell_ext_id':
44+
lambda x: f'<a target="labeler" href="http://10.0.1.145:8001/paper/{x}">{x}</a>'})
45+
)
46+
3947
gold_sota_records = gold_sota_records[gold_sota_records["parsed"] == gold_sota_records["parsed"]]
4048

4149
strip_cols=["task", "dataset", "format", "metric", "raw_value", "model", "model_type"]

sota_extractor2/models/linking/format.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ def format_to_regexp(format):
1212
fn=lambda x: x
1313
for i, s in enumerate(placeholders):
1414
if i % 2 == 0:
15-
regexp += escaped_whitespace_re.sub(r"\\s+", re.escape(s))
15+
if s.strip() == "":
16+
regexp += escaped_whitespace_re.sub(r"\\s+", re.escape(s))
17+
else:
18+
regexp += escaped_whitespace_re.sub(r"\\s*", re.escape(s))
1619
elif s.strip() == "":
1720
regexp += float_value_nc.pattern
1821
else:
@@ -29,6 +32,6 @@ def extract_value(cell_value, format):
2932
cell_value = re.sub(r"\s+%", "%", cell_value)
3033
regexp, fn = format_to_regexp(format)
3134
match = regexp.match(cell_value.strip())
32-
if match is None:
35+
if match is None or not len(match.groups()):
3336
return Decimal('NaN')
3437
return fn(Decimal(match.group(1)))

sota_extractor2/models/linking/probs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
from collections import Counter
22

3+
34
def get_probs(occurrences):
5+
"""
6+
Computes conditional probabilities based on frequency of co-occurrences
7+
8+
Parameters
9+
----------
10+
occurrences: occurrences[x][y] = number of times with (X=x and Y=y)
11+
12+
Returns
13+
-------
14+
probs : probs[x][y] = Pr(Y=y | X=x)
15+
reverse_probs : reverse_probs[y][x] = Pr(X=x | Y=y)
16+
"""
417
probs = {}
518
reverse_probs = {}
619
y_occ = Counter()
@@ -27,7 +40,7 @@ def reverse_probs(probs):
2740
2841
Returns
2942
-------
30-
reverse : reverse[y][x] = Pr(X=x | Y=y) assuming X and Y are uniform
43+
reverse : reverse[y][x] = Pr(X=x | Y=y) assuming X is uniform
3144
"""
3245
reverse = {}
3346
for x, probs_x in probs.items():

sota_extractor2/models/linking/proposals_filters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def __init__(self, confidence=-1):
8686
self.confidence = confidence
8787

8888
def _filter(self, proposals):
89-
which = proposals.confidence > self.confidence
90-
reason = "confidence " + proposals[~which].confidence.round(2).astype(str) + f" <= {self.confidence}"
89+
which = proposals.confidence >= self.confidence
90+
reason = "confidence " + proposals[~which].confidence.round(2).astype(str) + f" < {self.confidence}"
9191
return which, reason[~which]
9292

9393
def log(self, **kwargs):

sota_extractor2/models/linking/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def remove_parens(text):
1515
return parens_re.sub("", text)
1616

1717
def clean_name(name):
18-
return remove_parens(name.replace('\xa0', ' ').strip()).strip()
18+
return remove_parens(unidecode(name).strip()).strip()
1919

2020
def clean_cell(cell):
2121
return strip_nonalnum(clean_name(cell))

0 commit comments

Comments
 (0)