Skip to content

Commit fca59f1

Browse files
author
Marcin Kardas
committed
Refactor linking error analysis
1 parent 006dbaf commit fca59f1

File tree

10 files changed

+287
-22
lines changed

10 files changed

+287
-22
lines changed

sota_extractor2/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
# otherwise use these files
1212
data = Path("/mnt/efs/pwc/data")
13-
goldtags_dump = data / "dumps" / "goldtags-2019.09.13_0219.json.gz"
13+
goldtags_dump = data / "dumps" / "goldtags-2019.10.15_2227.json.gz"
1414

1515

1616
elastic = dict(hosts=['localhost'], timeout=20)

sota_extractor2/data/paper_collection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,12 @@ def _load_tables(path, annotations, jobs, migrate):
7373
return {f.parent.name: tbls for f, tbls in zip(files, tables)}
7474

7575

76-
def _load_annotated_papers(path):
77-
dump = load_gql_dump(path, compressed=path.suffix == ".gz")["allPapers"]
76+
def _load_annotated_papers(data_or_path):
    """
    Load annotated papers from a GraphQL dump.

    data_or_path - either an already loaded dump (a dict) or a path to a dump
                   file; the value is handed to load_gql_dump unchanged.
                   A path ending with ".gz" is loaded as compressed.

    Returns a dict of annotations keyed by remove_arxiv_version(arxiv_id)
    and by the unmodified arxiv_id; on a key collision the entry keyed by
    the unmodified id wins, because it is written last.
    """
    if isinstance(data_or_path, dict):
        compressed = False
    else:
        # Accept plain string paths as well as Path objects: only the suffix
        # is inspected here, the original argument is passed through as is.
        from pathlib import Path
        compressed = Path(data_or_path).suffix == ".gz"
    dump = load_gql_dump(data_or_path, compressed=compressed)["allPapers"]
    annotations = {remove_arxiv_version(a.arxiv_id): a for a in dump}
    annotations.update({a.arxiv_id: a for a in dump})
    return annotations

sota_extractor2/data/table.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from dataclasses import dataclass, field
77
from typing import List
8-
from ..helpers.jupyter import display_table
8+
from ..helpers.jupyter import display_html, table_to_html
99
from copy import deepcopy
1010

1111

@@ -21,6 +21,7 @@ class Cell:
2121
reference_re = re.compile(r"<ref id='([^']*)'>(.*?)</ref>")
2222
num_re = re.compile(r"^\d+$")
2323

24+
2425
def extract_references(s):
2526
parts = reference_re.split(s)
2627
refs = parts[1::3]
@@ -89,10 +90,16 @@ def __init__(self, name, df, layout, caption=None, figure_id=None, annotations=N
8990
if layout is not None:
9091
self.set_layout(layout)
9192

93+
self._set_annotations(annotations, migrate=migrate, old_name=old_name, guessed_tags=guessed_tags)
94+
95+
def _set_annotations(self, annotations, migrate=False, old_name=None, guessed_tags=None):
9296
if annotations is not None:
9397
self.gold_tags = annotations.gold_tags.strip()
9498
self.dataset_text = annotations.dataset_text.strip()
9599
self.notes = annotations.notes.strip()
100+
101+
sota_records = json.loads(annotations.cells_sota_records)
102+
96103
if guessed_tags is not None:
97104
tags = guessed_tags.values
98105
else:
@@ -117,14 +124,24 @@ def __init__(self, name, df, layout, caption=None, figure_id=None, annotations=N
117124
self.gold_tags = ''
118125
self.dataset_text = ''
119126
self.notes = ''
127+
sota_records = {}
128+
129+
sota_records = pd.DataFrame(sota_records.values(), index=sota_records.keys(),
130+
columns=['task', 'dataset', 'metric', 'format', 'model', 'value'])
131+
sota_records.index = self.name + "/" + sota_records.index
132+
sota_records.index.rename("cell_ext_id", inplace=True)
133+
sota_records.rename(columns={"value": "raw_value"}, inplace=True)
134+
135+
self.sota_records = sota_records.replace("", np.nan).dropna(subset=["model", "metric", "task", "dataset"])
136+
120137

121138
def set_layout(self, layout):
122139
for r, row in layout.iterrows():
123140
for c, cell in enumerate(row):
124141
self.df.iloc[r, c].layout = cell
125142

126143
def set_tags(self, tags):
127-
for r, row in tags.iterrows():
144+
for r, row in enumerate(tags):
128145
for c, cell in enumerate(row):
129146
# todo: change gold_tags to tags to avoid confusion
130147
self.df.iloc[r,c].gold_tags = cell.strip()
@@ -133,6 +150,10 @@ def set_tags(self, tags):
133150
def matrix(self):
134151
return self.df.applymap(lambda x: x.value)
135152

153+
@property
154+
def matrix_html(self):
155+
return self.df.applymap(lambda x: raw_value_to_html(x.raw_value))
156+
136157
@property
137158
def matrix_layout(self):
138159
return self.df.applymap(lambda x: x.layout)
@@ -169,8 +190,11 @@ def from_file(cls, path, metadata, annotations=None, migrate=False, match_name=N
169190
table_ann = None
170191
return cls(metadata['filename'], df, layout, metadata.get('caption'), metadata.get('figure_id'), table_ann, migrate, match_name, guessed_tags)
171192

193+
def _repr_html_(self):
194+
return table_to_html(self.matrix_html.values, self.matrix_tags.values, self.matrix_layout.values)
195+
172196
def display(self):
173-
display_table(self.df.applymap(lambda x: raw_value_to_html(x.raw_value)).values, self.df.applymap(lambda x: x.gold_tags).values, self.df.applymap(lambda x:x.layout).values)
197+
display_html(self._repr_html_())
174198

175199
def _save_df(self, df, filename):
176200
df.to_csv(filename, header=None, index=None)

sota_extractor2/helpers/explainers.py

Lines changed: 160 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,98 @@
1+
from sota_extractor2.models.linking.metrics import Metrics
12
from ..models.structure import TableType
23
from ..loggers import StructurePredictionEvaluator, LinkerEvaluator, FilteringEvaluator
34
import pandas as pd
5+
import numpy as np
6+
from ..helpers.jupyter import table_to_html
7+
from sota_extractor2.models.linking.format import extract_value
48

59

6-
class TableTypeExplainer:
10+
class Reason:
11+
pass
12+
13+
14+
class IrrelevantTable(Reason):
715
def __init__(self, paper, table, table_type, probs):
816
self.paper = paper
917
self.table = table
1018
self.table_type = table_type
1119
self.probs = pd.DataFrame(probs, columns=["type", "probability"])
1220

1321
def __str__(self):
14-
return f"Table {self.table.name} was labelled as {self.table_type}."
22+
return f"Table {self.table.name} was labelled as {self.table_type.name}."
23+
24+
def _repr_html_(self):
25+
prediction = f'<div>{self}</div>'
26+
caption = f'<div>Caption: {self.table.caption}</div>'
27+
probs = self.probs.style.format({"probability": "{:.2f}"})._repr_html_()
28+
return prediction + caption + probs
29+
30+
31+
class MislabeledCell(Reason):
    """
    Reason attached to a single table cell.

    Keeps every constructor argument on the instance so that the cell
    position and the probabilities are available to whoever renders or
    inspects the reason later (previously row, col and probs were dropped).
    """

    def __init__(self, paper, table, row, col, probs):
        self.paper = paper
        self.table = table
        self.row = row
        self.col = col
        # NOTE(review): stored as passed; the expected shape of probs is not
        # visible from here - confirm with the caller before formatting it.
        self.probs = probs
35+
36+
37+
class TableExplanation:
    """
    HTML explanation of the linking proposals made for a single table.

    paper      - paper the table belongs to
    table      - table object; must expose name, caption, matrix_html,
                 matrix_tags and matrix_layout
    table_type - predicted table type (its .name is shown in the header)
    proposals  - DataFrame of proposals indexed by cell_ext_id
                 ("<paper_id>/<table_name>/<row>.<col>")
    reasons    - filtering reasons indexed by cell_ext_id; a proposal whose
                 id is present here is treated as filtered out
    topk       - per-cell top-k data indexed by (row, col)
    """

    def __init__(self, paper, table, table_type, proposals, reasons, topk):
        self.paper = paper
        self.table = table
        self.table_type = table_type
        self.proposals = proposals
        self.reasons = reasons
        self.topk = topk

    @staticmethod
    def _empty_grid(shape):
        """Return an object ndarray of the given shape filled with ''."""
        # np.zeros_like(..., dtype=object) would fill the grid with the
        # integer 0, which ends up rendered as class "0" / title "0" for
        # every cell that has no proposal.
        return np.full(shape, '', dtype=object)

    def _format_tooltip(self, proposal):
        """Format the fields of a proposal as a multi-line tooltip text."""
        return f"dataset: {proposal.dataset}\n" \
               f"metric: {proposal.metric}\n" \
               f"task: {proposal.task}\n" \
               f"score: {proposal.parsed}\n" \
               f"confidence: {proposal.confidence:0.2f}"

    def _format_topk(self, topk):
        """Format top-k data of a cell; currently always an empty string."""
        return ""

    def _repr_html_(self):
        """
        Render the table with per-cell tooltips, followed by the list of
        proposals that were not filtered out (when there are any).
        """
        matrix = self.table.matrix_html.values
        predictions = self._empty_grid(matrix.shape)
        tooltips = self._empty_grid(matrix.shape)
        for cell_ext_id, proposal in self.proposals.iterrows():
            _, _, rc = cell_ext_id.split("/")
            row, col = [int(x) for x in rc.split('.')]
            if cell_ext_id in self.reasons:
                reason = self.reasons[cell_ext_id]
                tooltips[row, col] = reason
                if reason.startswith("replaced by "):
                    tooltips[row, col] += "\n\n" + self._format_tooltip(proposal)
                elif reason.startswith("confidence "):
                    # A filtered cell is not guaranteed to have top-k data.
                    try:
                        cell_topk = self.topk[row, col]
                    except (KeyError, IndexError):
                        cell_topk = None
                    tooltips[row, col] += "\n\n" + self._format_topk(cell_topk)
            else:
                predictions[row, col] = 'final-proposal'
                tooltips[row, col] = self._format_tooltip(proposal)

        table_type_html = f'<div>Table {self.table.name} was labelled as {self.table_type.name}.</div>'
        caption_html = f'<div>Caption: {self.table.caption}</div>'
        table_html = table_to_html(matrix,
                                   self.table.matrix_tags.values,
                                   self.table.matrix_layout.values,
                                   predictions,
                                   tooltips)
        html = table_type_html + caption_html + table_html
        proposals = self.proposals[~self.proposals.index.isin(self.reasons.index)]
        if len(proposals):
            proposals = proposals[["dataset", "metric", "task", "model", "parsed"]]\
                .reset_index(drop=True).rename(columns={"parsed": "score"})
            html2 = proposals._repr_html_()
            return f"<div><div>{html}</div><div>Proposals</div><div>{html2}</div></div>"
        return html
1989

2090

2191
class Explainer:
92+
_sota_record_columns = ['task', 'dataset', 'metric', 'format', 'model', 'model_type', 'raw_value', 'parsed']
93+
2294
def __init__(self, pipeline_logger, paper_collection):
95+
self.paper_collection = paper_collection
2396
self.spe = StructurePredictionEvaluator(pipeline_logger, paper_collection)
2497
self.le = LinkerEvaluator(pipeline_logger, paper_collection)
2598
self.fe = FilteringEvaluator(pipeline_logger)
@@ -29,15 +102,94 @@ def explain(self, paper, cell_ext_id):
29102
if paper.paper_id != paper_id:
30103
return "No such cell"
31104

32-
row, col = [int(x) for x in rc.split('.')]
33-
34105
table_type, probs = self.spe.get_table_type_predictions(paper_id, table_name)
35106

36107
if table_type == TableType.IRRELEVANT:
37-
return TableTypeExplainer(paper, paper.table_by_name(table_name), table_type, probs)
108+
return IrrelevantTable(paper, paper.table_by_name(table_name), table_type, probs)
109+
110+
all_proposals = self.le.proposals[paper_id]
111+
reasons = self.fe.reason
112+
table_ext_id = f"{paper_id}/{table_name}"
113+
table_proposals = all_proposals[all_proposals.index.str.startswith(table_ext_id+"/")]
114+
topk = {(row, col): topk for (pid, tn, row, col), topk in self.le.topk.items()
115+
if (pid, tn) == (paper_id, table_name)}
116+
117+
return TableExplanation(paper, paper.table_by_name(table_name), table_type, table_proposals, reasons, topk)
118+
119+
row, col = [int(x) for x in rc.split('.')]
38120

39121
reason = self.fe.reason.get(cell_ext_id)
40122
if reason is None:
41123
pass
42124
else:
43125
return reason
126+
127+
def _get_table_sota_records(self, table):
    """
    Build the gold sota records of a single table.

    Starts from a copy of table.sota_records (indexed by
    "<table_name>/<row>.<col>") and, for every record, fills in:
      model_type - first tag starting with 'model' found in the cell's
                   column, or, when there is none, in the cell's row
      raw_value  - content of the cell taken from table.matrix
      parsed     - raw_value converted with extract_value using the
                   record's format
    Records whose parsed value is NaN are dropped and the text columns are
    stripped. Returns a DataFrame with the _sota_record_columns columns;
    an empty one (index named "cell_ext_id") when the table has no records.
    """
    if not len(table.sota_records):
        empty = pd.DataFrame(columns=self._sota_record_columns)
        empty.index.rename("cell_ext_id", inplace=True)
        return empty

    def first_model(tags):
        return next((tag for tag in tags if tag.startswith('model')), '')

    matrix = table.matrix.values
    tags = table.matrix_tags
    model_type_col = tags.apply(first_model)
    model_type_row = tags.T.apply(first_model)
    sota_records = table.sota_records.copy()

    # Collect the per-record values and assign whole columns afterwards.
    # Writing to the row objects yielded by iterrows() is not reliable:
    # they may be copies, in which case the frame is never updated.
    model_types = []
    raw_values = []
    for cell_ext_id in sota_records.index:
        _, rc = cell_ext_id.split('/')
        row, col = [int(x) for x in rc.split('.')]
        model_types.append(model_type_col[col] or model_type_row[row])
        raw_values.append(matrix[row, col])
    sota_records['model_type'] = model_types
    sota_records['raw_value'] = raw_values

    sota_records["parsed"] = sota_records[["raw_value", "format"]].apply(
        lambda row: float(extract_value(row["raw_value"], row["format"])), axis=1)

    # keep only the records whose value could be parsed
    sota_records = sota_records[sota_records["parsed"].notna()]

    strip_cols = ["task", "dataset", "format", "metric", "raw_value", "model", "model_type"]
    sota_records = sota_records.transform(
        lambda x: x.str.strip() if x.name in strip_cols else x)
    return sota_records[self._sota_record_columns]
157+
158+
def _get_sota_records(self, paper):
159+
if not len(paper.tables):
160+
empty = pd.DataFrame(columns=self._sota_record_columns)
161+
empty.index.rename("cell_ext_id", inplace=True)
162+
return empty
163+
records = [self._get_table_sota_records(table) for table in paper.tables]
164+
records = pd.concat(records)
165+
records.index = paper.paper_id + "/" + records.index
166+
records.index.rename("cell_ext_id", inplace=True)
167+
return records
168+
169+
def linking_metrics(self, experiment_name="unk"):
170+
paper_ids = list(self.le.proposals.keys())
171+
172+
proposals = pd.concat(self.le.proposals.values())
173+
proposals = proposals[~proposals.index.isin(self.fe.reason.index)]
174+
175+
papers = {paper_id: self.paper_collection.get_by_id(paper_id) for paper_id in paper_ids}
176+
missing = [paper_id for paper_id, paper in papers.items() if paper is None]
177+
if missing:
178+
print("Missing papers in paper collection:")
179+
print(", ".join(missing))
180+
papers = [paper for paper in papers.values() if paper is not None]
181+
182+
if not len(papers):
183+
gold_sota_records = pd.DataFrame(columns=self._sota_record_columns)
184+
gold_sota_records.index.rename("cell_ext_id", inplace=True)
185+
else:
186+
gold_sota_records = pd.concat([self._get_sota_records(paper) for paper in papers])
187+
188+
df = gold_sota_records.merge(proposals, 'outer', left_index=True, right_index=True, suffixes=['_gold', '_pred'])
189+
df = df.reindex(sorted(df.columns), axis=1)
190+
df = df.fillna('not-present')
191+
if "experiment_name" in df.columns:
192+
del df["experiment_name"]
193+
194+
metrics = Metrics(df, experiment_name=experiment_name)
195+
return metrics

sota_extractor2/helpers/jupyter.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from IPython.core.display import display, HTML
22
from .table_style import table_style
3+
import numpy as np
4+
5+
36
def set_seed(seed, name):
47
import torch
58
import numpy as np
@@ -9,11 +12,11 @@ def set_seed(seed, name):
912
torch.backends.cudnn.benchmark = False
1013
np.random.seed(seed)
1114

12-
def display_html(s): return display(HTML(s))
1315

16+
def display_html(s): return display(HTML(s))
1417

1518

16-
def display_table(table, structure=None, layout=None):
19+
def table_to_html(table, structure=None, layout=None, predictions=None, tooltips=None):
1720
"""
1821
matrix - 2d ndarray with cell values
1922
structure - 2d ndarray with structure annotation
@@ -24,15 +27,22 @@ def display_table(table, structure=None, layout=None):
2427
matrix = table
2528
if structure is None: structure = table.matrix_gold_tags
2629
if layout is None: layout = np.zeros_like(matrix, dtype=str)
30+
if predictions is None: predictions = np.zeros_like(matrix, dtype=str)
31+
if tooltips is None: tooltips = np.zeros_like(matrix, dtype=str)
2732
html = []
2833
html.append(table_style)
2934
html.append('<div class="tableWrapper">')
3035
html.append("<table>")
31-
for row,struc_row, layout_row in zip(matrix, structure, layout):
36+
for row,struc_row, layout_row, preds_row, tt_row in zip(matrix, structure, layout, predictions, tooltips):
3237
html.append("<tr>")
33-
for cell,struct,layout in zip(row,struc_row,layout_row):
34-
html.append(f'<td class="{struct} {layout}">{cell}</td>')
38+
for cell,struct,layout,preds, tt in zip(row,struc_row,layout_row,preds_row, tt_row):
39+
html.append(f'<td class="{struct} {layout} {preds}" title="{tt}">{cell}</td>')
3540
html.append("</tr>")
3641
html.append("</table>")
3742
html.append('</div>')
43+
return "\n".join(html)
44+
45+
46+
def display_table(table, structure=None, layout=None):
    """
    Render a table with table_to_html and display the result.

    All arguments are forwarded unchanged to table_to_html.
    """
    # table_to_html already returns a single joined string; joining it again
    # would put a newline between every character of the markup.
    display_html(table_to_html(table, structure, layout))

0 commit comments

Comments
 (0)