
Commit ed6aa99

Author: Marcin Kardas
Add evaluation and extraction helpers
1 parent 9ab8a8e commit ed6aa99

3 files changed: +149 -0 lines


axcell/helpers/evaluate.py

Lines changed: 99 additions & 0 deletions (new file)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import re
import pandas as pd

from axcell.data.paper_collection import remove_arxiv_version


def norm_score_str(x):
    # Canonicalize a score as a string: round 9-repeating decimals up,
    # strip trailing zeros and a dangling decimal point, map '-0' to '0'.
    x = str(x)
    if re.match(r'^(\+|-|)(\d+)\.9{5,}$', x):
        x = re.sub(r'^(\+|-|)(\d+)\.9{5,}$', lambda a: a.group(1) + str(int(a.group(2)) + 1), x)
    elif x.endswith('9' * 5) and '.' in x:
        x = re.sub(r'([0-8])9+$', lambda a: str(int(a.group(1)) + 1), x)
    if '.' in x:
        x = re.sub(r'0+$', '', x)
    if x[-1] == '.':
        x = x[:-1]
    if x == '-0':
        x = '0'
    return x
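
# Examples (illustrative annotations, not in the original file):
#   norm_score_str('91.99999')  -> '92'
#   norm_score_str('0.7299999') -> '0.73'
#   norm_score_str('45.600')    -> '45.6'
#   norm_score_str('-0')        -> '0'
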
epsilon = 1e-10  # smoothing term that avoids division by zero


def precision(tp, fp):
    pred_positives = tp + fp + epsilon
    return (1.0 * tp) / pred_positives


def recall(tp, fn):
    true_positives = tp + fn + epsilon
    return (1.0 * tp) / true_positives


def f1(prec, recall):
    norm = prec + recall + epsilon
    return 2 * prec * recall / norm
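
# Sanity check (illustrative): with tp=3, fp=1, fn=2,
# precision(3, 1) ~= 0.75, recall(3, 2) ~= 0.6 and
# f1(0.75, 0.6) = 2 * 0.75 * 0.6 / 1.35 ~= 0.667; epsilon keeps the
# degenerate precision(0, 0) at 0.0 instead of dividing by zero.
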

def stats(predictions, ground_truth, axis=None):
    # Compare predictions against gold records along the chosen axis
    # ('tdm', 'task', 'dataset', 'metric', or full TDMS when axis is None
    # or 'tdms') and report micro- and macro-averaged precision/recall/F1.
    gold = pd.DataFrame(ground_truth, columns=["paper", "task", "dataset", "metric", "value"])
    pred = pd.DataFrame(predictions, columns=["paper", "task", "dataset", "metric", "value"])

    if axis == 'tdm':
        columns = ['paper', 'task', 'dataset', 'metric']
    elif axis == 'tdms' or axis is None:
        columns = ['paper', 'task', 'dataset', 'metric', 'value']
    else:
        columns = ['paper', axis]
    gold = gold[columns].drop_duplicates()
    pred = pred[columns].drop_duplicates()

    # An outer merge marks each record as matched ('both'), gold-only
    # or prediction-only, which maps directly to TP/FN/FP.
    results = gold.merge(pred, on=columns, how="outer", indicator=True)

    is_correct = results["_merge"] == "both"
    no_pred = results["_merge"] == "left_only"
    no_gold = results["_merge"] == "right_only"

    results["TP"] = is_correct.astype('int8')
    results["FP"] = no_gold.astype('int8')
    results["FN"] = no_pred.astype('int8')

    m = results.groupby(["paper"]).agg({"TP": "sum", "FP": "sum", "FN": "sum"})
    m["precision"] = precision(m.TP, m.FP)
    m["recall"] = recall(m.TP, m.FN)
    m["f1"] = f1(m.precision, m.recall)

    TP_ALL = m.TP.sum()
    FP_ALL = m.FP.sum()
    FN_ALL = m.FN.sum()

    prec, reca = precision(TP_ALL, FP_ALL), recall(TP_ALL, FN_ALL)
    return {
        'Micro Precision': prec,
        'Micro Recall': reca,
        'Micro F1': f1(prec, reca),
        'Macro Precision': m.precision.mean(),
        'Macro Recall': m.recall.mean(),
        'Macro F1': m.f1.mean()
    }


def evaluate(predictions, ground_truth):
    # Normalize scores and strip arXiv version suffixes, then score the
    # predictions on full TDMS tuples and on each individual axis.
    predictions = predictions.copy()
    ground_truth = ground_truth.copy()
    predictions['value'] = predictions['score' if 'score' in predictions else 'value'].apply(norm_score_str)
    ground_truth['value'] = ground_truth['score' if 'score' in ground_truth else 'value'].apply(norm_score_str)
    predictions['paper'] = predictions['arxiv_id'].apply(remove_arxiv_version)
    ground_truth['paper'] = ground_truth['arxiv_id'].apply(remove_arxiv_version)

    metrics = []
    for axis in [None, "tdm", "task", "dataset", "metric"]:
        s = stats(predictions, ground_truth, axis)
        s['type'] = {'tdms': 'TDMS', 'tdm': 'TDM', 'task': 'Task', 'dataset': 'Dataset', 'metric': 'Metric'}.get(axis)
        metrics.append(s)
    columns = ['Micro Precision', 'Micro Recall', 'Micro F1', 'Macro Precision', 'Macro Recall', 'Macro F1']
    return pd.DataFrame(metrics, columns=columns)
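
For orientation, a minimal usage sketch of evaluate (not part of the commit; the arXiv ids and scores below are made up, and either a 'score' or a 'value' column is accepted):

import pandas as pd
from axcell.helpers.evaluate import evaluate

predictions = pd.DataFrame([{
    'arxiv_id': '1234.56789v2', 'task': 'Image Classification',
    'dataset': 'CIFAR-10', 'metric': 'Accuracy', 'score': '96.5399999'
}])
ground_truth = pd.DataFrame([{
    'arxiv_id': '1234.56789v1', 'task': 'Image Classification',
    'dataset': 'CIFAR-10', 'metric': 'Accuracy', 'value': '96.54'
}])
# Both rows normalize to paper '1234.56789' with value '96.54', so every
# micro/macro metric evaluates to 1.0 (up to the epsilon smoothing).
print(evaluate(predictions, ground_truth))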

axcell/helpers/results_extractor.py

Lines changed: 49 additions & 0 deletions (new file)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from axcell.data.structure import CellEvidenceExtractor
from axcell.models.structure import TableType, TableStructurePredictor, TableTypePredictor
from axcell.models.linking import *
from pathlib import Path


class ResultsExtractor:
    # End-to-end pipeline: classify table types, label table structure,
    # link cells to the taxonomy and filter the linked proposals.
    def __init__(self, models_path):
        models_path = Path(models_path)
        self.cell_evidences = CellEvidenceExtractor()
        self.ttp = TableTypePredictor(models_path, "table-type-classifier.pth")
        self.tsp = TableStructurePredictor(models_path, "table-structure-classifier.pth")
        self.taxonomy = Taxonomy(taxonomy=models_path / "taxonomy.json", metrics_info=models_path / "metrics.json")

        self.evidence_finder = EvidenceFinder(self.taxonomy, abbreviations_path=models_path / "abbreviations.json")
        self.context_search = ContextSearch(self.taxonomy, self.evidence_finder)
        self.dataset_extractor = DatasetExtractor(self.evidence_finder)

        self.linker = Linker("linking", self.context_search, self.dataset_extractor)
        # Filter chain: structurally sound proposals, confidence >= 0.8,
        # best result per paper, confidence >= 0.85.
        self.filters = StructurePredictionFilter() >> ConfidenceFilter(0.8) >> \
            BestResultFilter(self.taxonomy, context="paper") >> ConfidenceFilter(0.85)

    def __call__(self, paper, tables=None, in_place=False):
        if tables is None:
            tables = paper.tables
        tables_types = self.ttp.predict(paper, tables)
        if in_place:
            types = {
                TableType.SOTA: 'leaderboard',
                TableType.ABLATION: 'ablation',
                TableType.IRRELEVANT: 'irrelevant'
            }
            for table, table_type in zip(paper.tables, tables_types):
                table.gold_tags = types[table_type]
        sota_tables = [
            table for table, table_type in zip(paper.tables, tables_types)
            if table_type != TableType.IRRELEVANT
        ]
        paper.text.save()
        evidences = self.cell_evidences(paper, sota_tables)
        labeled_tables = self.tsp.label_tables(paper, sota_tables, evidences, in_place=in_place, use_crf=False)

        proposals = self.linker(paper, labeled_tables)
        proposals = self.filters(proposals)
        proposals = proposals[["dataset", "metric", "task", "model", "parsed"]] \
            .reset_index(drop=True).rename(columns={"parsed": "score"})
        return proposals
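
And a sketch of driving the extractor end to end (hypothetical path; paper stands for an axcell paper object loaded elsewhere, e.g. through axcell's PaperCollection):

from axcell.helpers.results_extractor import ResultsExtractor

extract_results = ResultsExtractor('models/')  # hypothetical directory holding the .pth and .json files named above
proposals = extract_results(paper)             # paper: an axcell paper object with .tables and .text
print(proposals)                               # columns: dataset, metric, task, model, score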

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@ dependencies:
 - docker-py=4.1.0
 - python-magic=0.4.15
 - html5lib=1.0.1
+- seaborn

0 commit comments
