Commit 081bd5c
Author: Marcin Kardas

Predict tasks, datasets and metrics independently

* compute probabilities for each axis (tasks, datasets and metrics) independently
* fix metric score extraction and conversion

1 parent d005f9c commit 081bd5c

5 files changed: +118 -39 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -98,3 +98,6 @@ venv.bak/
 .mypy_cache/
 .idea/*
 .vscode/settings.json
+
+# pytest
+.pytest_cache
```

README.md

Lines changed: 6 additions & 0 deletions
````diff
@@ -34,3 +34,9 @@ To test the whole extraction on a single file run
 ```
 make test
 ```
+
+### Unit Tests
+
+```
+PYTHONPATH=. py.test
+```
````
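The new unit-test entry point pairs naturally with the conversion fix in this commit. As a minimal sketch of what such a test could look like (the file name is hypothetical, and it assumes `extract_value("95.5%", "{x}%")` yields `Decimal('95.5')` and that `to_percentage` returns the percent value unchanged for `'%'`-unit inputs), using the `convert_metric` helper introduced in the bm25_naive.py diff below:

```python
# test_convert_metric.py -- hypothetical test sketch, not part of this commit
from decimal import Decimal

from sota_extractor2.models.linking.bm25_naive import convert_metric


def test_percentage_rescaled_to_0_1_range():
    # "95.5%" is parsed as a percentage, then rescaled for a 0-1 ranged metric
    assert convert_metric("95.5%", "0-1", complementary=False) == Decimal("0.955")


def test_complementary_percentage_metric():
    # a 10% error reported for an accuracy-style metric flips to 90
    assert convert_metric("10%", "1-100", complementary=True) == Decimal("90")
```

Run with `PYTHONPATH=. py.test`, as the README now documents.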

sota_extractor2/models/linking/bm25_naive.py

Lines changed: 50 additions & 24 deletions
```diff
@@ -9,7 +9,6 @@
 import spacy
 from scispacy.abbreviation import AbbreviationDetector
 from sota_extractor2.models.linking.format import extract_value
-from functools import total_ordering
 
 
 @dataclass()
@@ -60,20 +59,19 @@ def __str__(self):
         return f"{self.model_name}: {self.raw_value} on {self.dataset}"
 
 
-@total_ordering
-class MetricValue(Decimal):
+class MetricValue:
     value: Decimal
     unit: str = None
 
-    def __new__(cls, value, unit):
-        return super().__new__(cls, value / Decimal(100) if unit is '%' else value)
-
     def __init__(self, value, unit):
         self.value = value
         self.unit = unit
 
+    def to_unitless(self):
+        return self.value
+
     def to_absolute(self):
-        return Decimal(self)
+        return self.value / Decimal(100) if self.unit is '%' else self.value
 
     # unit = None means that no unit was specified, so we have to guess the unit.
     # if there's a value "21" in a table's cell, then we guess if it's 21 or 0.21 (i.e., 21%)
@@ -84,10 +82,13 @@ def to_percentage(self):
         return self.value
 
     def complement(self):
-        if self.unit is None and 1 < self.value < 100:
-            value = 100 - self.value
+        if self.unit is None:
+            if 1 < self.value < 100:
+                value = 100 - self.value
+            else:
+                value = 1 - self.value
         else:
-            value = 1 - self.value
+            value = 100 - self.value
         return MetricValue(value, self.unit)
 
     def __repr__(self):
@@ -211,6 +212,30 @@ def handle_pm(value):
 # %%
 
 
+def convert_metric(raw_value, rng, complementary):
+    format = "{x}"
+
+    percentage = '%' in raw_value
+    if percentage:
+        format += '%'
+
+    with localcontext() as ctx:
+        ctx.traps[InvalidOperation] = 0
+        parsed = extract_value(raw_value, format)
+        parsed = MetricValue(parsed, '%' if percentage else None)
+
+        if complementary:
+            parsed = parsed.complement()
+        if rng == '0-1':
+            parsed = parsed.to_percentage() / 100
+        elif rng == '1-100':
+            parsed = parsed.to_percentage()
+        elif rng == 'abs':
+            parsed = parsed.to_absolute()
+        else:
+            parsed = parsed.to_unitless()
+    return parsed
+
 proposal_columns = ['dataset', 'metric', 'task', 'format', 'raw_value', 'model', 'model_type', 'cell_ext_id',
                     'confidence', 'parsed', 'struct_model_type', 'struct_dataset']
 
@@ -267,26 +292,27 @@ def linked_proposals(proposals):
             df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
             for _, row in df.iterrows():
                 raw_value = prop.raw_value
+                task = row['task']
+                dataset = row['dataset']
                 metric = row['metric']
 
-                with localcontext() as ctx:
-                    ctx.traps[InvalidOperation] = 0
-                    parsed = extract_value(raw_value, format)
-                    parsed = MetricValue(parsed, '%' if percentage else None)
+                complementary = False
+                if metric != row['true_metric']:
+                    metric = row['true_metric']
+                    complementary = True
 
-                if metric != row['true_metric']:
-                    metric = row['true_metric']
-                    parsed = parsed.complement()
+                # todo: pass taxonomy directly to proposals generation
+                ranges = taxonomy_linking.taxonomy.metrics_range
+                key = (task, dataset, metric)
+                rng = ranges.get(key, '')
+                if not rng: rng = ranges.get(metric, '')
 
-                if set(metric.lower().split()) & {"error", "accuracy", "bleu", "f1", "precision", "recall"}:
-                    parsed = float(parsed.to_percentage() / 100)
-                else:
-                    parsed = float(parsed.to_absolute())
+                parsed = float(convert_metric(raw_value, rng, complementary))
 
                 linked = {
-                    'dataset': row['dataset'],
+                    'dataset': dataset,
                     'metric': metric,
-                    'task': row['task'],
+                    'task': task,
                     'format': format,
                     'raw_value': raw_value,
                     'model': prop.model_name,
@@ -305,7 +331,7 @@ def linked_proposals(proposals):
     return proposals
 
 
-def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=MatchSearch(),
+def linked_proposals(paper_ext_id, paper, annotated_tables, taxonomy_linking=None,
                      dataset_extractor=None, topk=1):
     # dataset_extractor=DatasetExtractor()):
     proposals = []
```
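The net effect: instead of guessing the scale from metric-name keywords, `linked_proposals` now looks up the metric's annotated range, first by the full `(task, dataset, metric)` triple and then by metric name alone, and delegates conversion to `convert_metric`. A small illustrative sketch follows; the `ranges` entries are invented, and only the lookup-with-fallback pattern and the conversions mirror the diff (the `'0-1'` result also assumes `to_percentage` returns the percent value unchanged for `'%'`-unit inputs):

```python
from decimal import Decimal

# invented range annotations; real ones come from taxonomy.metrics_range
ranges = {
    ("Question Answering", "SQuAD", "F1"): "0-1",  # per-(task, dataset, metric) override
    "F1": "1-100",                                 # fallback keyed by metric name alone
}

key = ("Question Answering", "SQuAD", "F1")
rng = ranges.get(key, '')
if not rng: rng = ranges.get("F1", '')             # same fallback as in the diff

convert_metric("88.3%", rng, complementary=False)    # Decimal('0.883') for rng '0-1'
convert_metric("12.7%", "abs", complementary=False)  # Decimal('0.127'): '%' turned absolute
```

One caveat worth noting: `to_absolute` still compares units with `self.unit is '%'`, which only works because CPython interns short string literals; `self.unit == '%'` would be the robust spelling.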

sota_extractor2/models/linking/context_search.py

Lines changed: 58 additions & 14 deletions
```diff
@@ -116,8 +116,9 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
 
 # compute log-probabilities in a given context and add them to logprobs
 @njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
-                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
+def compute_logprobs(taxonomy, tasks, datasets, metrics,
+                     reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs, axes_logprobs):
     task_cache = typed.Dict.empty(types.unicode_type, types.float64)
     dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
     metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
@@ -130,6 +131,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task
             task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
 
         logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
+    for i, task in enumerate(tasks):
+        axes_logprobs[0][i] += task_cache[task]
+
+    for i, dataset in enumerate(datasets):
+        axes_logprobs[1][i] += dataset_cache[dataset]
+
+    for i, metric in enumerate(metrics):
+        axes_logprobs[2][i] += metric_cache[metric]
+
+
+def _to_typed_list(iterable):
+    l = typed.List()
+    for i in iterable:
+        l.append(i)
+    return l
 
 
 class ContextSearch:
@@ -145,9 +161,12 @@ def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), met
         self.queries = {}
         self.taxonomy = taxonomy
         self.evidence_finder = evidence_finder
-        self._taxonomy = typed.List()
-        for t in self.taxonomy.taxonomy:
-            self._taxonomy.append(t)
+
+        self._taxonomy = _to_typed_list(self.taxonomy.taxonomy)
+        self._taxonomy_tasks = _to_typed_list(self.taxonomy.tasks)
+        self._taxonomy_datasets = _to_typed_list(self.taxonomy.datasets)
+        self._taxonomy_metrics = _to_typed_list(self.taxonomy.metrics)
+
         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
         self.metrics_noise = metrics_noise if metrics_noise else context_noise
@@ -174,10 +193,10 @@ def _numba_extend_list(self, lst):
             l.append(x)
         return l
 
-    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
         context = context or ""
         abbrvs = self.extract_acronyms(context)
-        context = normalize_cell_ws(normalize_dataset(context))
+        context = normalize_cell_ws(normalize_dataset_ws(context))
         dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
         mss = set(self.evidence_finder.find_metrics(context))
         tss = set(self.evidence_finder.find_tasks(context))
@@ -191,21 +210,34 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs)
         dss = self._numba_extend_list(dss)
         mss = self._numba_extend_list(mss)
         tss = self._numba_extend_list(tss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
-                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)
+        compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
+                         self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs,
+                         axes_logprobs)
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
         n = len(self._taxonomy)
         context_logprobs = np.zeros(n)
+        axes_context_logprobs = _to_typed_list([
+            np.zeros(len(self._taxonomy_tasks)),
+            np.zeros(len(self._taxonomy_datasets)),
+            np.zeros(len(self._taxonomy_metrics)),
+        ])
 
         for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs, axes_context_logprobs)
         keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
         probs = softmax(np.array(logprobs))
-        return zip(keys, probs)
+        axes_probs = [softmax(np.array(a)) for a in axes_context_logprobs]
+        return (
+            zip(keys, probs),
+            zip(self.taxonomy.tasks, axes_probs[0]),
+            zip(self.taxonomy.datasets, axes_probs[1]),
+            zip(self.taxonomy.metrics, axes_probs[2])
+        )
 
     def __call__(self, query, datasets, caption, topk=1, debug_info=None):
         cellstr = debug_info.cell.cell_ext_id
@@ -229,8 +261,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
             ###print("Taking result from cache")
             p = self.queries[key]
         else:
-            dist = self.match((datasets, caption, query))
-            top_results = sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)]
+            dists = self.match((datasets, caption, query))
+
+            all_top_results = [sorted(dist, key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
+            top_results, top_results_t, top_results_d, top_results_m = all_top_results
 
             entries = []
             for it, prob in top_results:
@@ -239,6 +273,16 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
                 entry.update({"evidence": "", "confidence": prob})
                 entries.append(entry)
 
+            # entries = []
+            # for i in range(5):
+            #     best_independent = dict(
+            #         task=top_results_t[i][0],
+            #         dataset=top_results_d[i][0],
+            #         metric=top_results_m[i][0])
+            #     best_independent.update({"evidence": "", "confidence": top_results_t[i][1]})
+            #     entries.append(best_independent)
+            #entries = [best_independent] + entries
+
             # best, best_p = sorted(dist, key=lambda x: x[1], reverse=True)[0]
             # entry = et[best]
             # p = pd.DataFrame({k:[v] for k, v in entry.items()})
@@ -283,5 +327,5 @@ def from_paper(self, paper):
         return self(text)
 
     def __call__(self, text):
-        text = normalize_cell_ws(normalize_dataset(text))
+        text = normalize_cell_ws(normalize_dataset_ws(text))
         return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
```
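With this change `match` returns four distributions instead of one: the joint softmax over whole taxonomy entries plus an independent softmax per axis, which is what the commit title announces. A hypothetical consumer might look like the sketch below (`cs` is assumed to be a `ContextSearch` instance, and `datasets_ctx`, `caption`, `cell_query` the three context strings):

```python
# unpack the joint distribution and the three per-axis distributions
joint, task_dist, dataset_dist, metric_dist = cs.match((datasets_ctx, caption, cell_query))

# the joint distribution still ranks complete (task, dataset, metric) entries
best_entry, p_joint = max(joint, key=lambda kv: kv[1])

# each axis can now be ranked on its own (note: these are one-shot zip iterators)
best_task, p_task = max(task_dist, key=lambda kv: kv[1])
best_dataset, p_dataset = max(dataset_dist, key=lambda kv: kv[1])
best_metric, p_metric = max(metric_dist, key=lambda kv: kv[1])
```

The commented-out block in `__call__` shows per-axis winners being assembled into independent proposals, but it is left disabled in this commit; only the joint ranking feeds `entries`.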

sota_extractor2/models/linking/taxonomy.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -66,5 +66,5 @@ def _read_metrics_info(self, path):
             s[rng] = s.get(rng, 0) + 1
             mr[metric] = s
         for metric in mr:
-            metrics_range[metric] = sorted(mr[metric].items(), key=lambda x: x[1])[-1]
+            metrics_range[metric] = sorted(mr[metric].items(), key=lambda x: x[1])[-1][0]
         return metrics_info, metrics_range
```
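This one-character fix matters downstream: `sorted(d.items(), key=...)[-1]` returns a `(range, count)` pair, so `metrics_range` stored tuples that could never equal plain labels like `'0-1'` in the comparisons `convert_metric` performs; the trailing `[0]` keeps just the most frequent range label. A quick illustration with invented counts:

```python
votes = {'0-1': 2, '1-100': 5}                    # counts of observed range annotations
sorted(votes.items(), key=lambda x: x[1])[-1]     # ('1-100', 5)  -- old: a (label, count) tuple
sorted(votes.items(), key=lambda x: x[1])[-1][0]  # '1-100'       -- fixed: the label only
```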
