
Commit 2cea810

Implement filters optimization
1 parent ac2a5b5 commit 2cea810

3 files changed: +276 -5 lines changed


sota_extractor2/helpers/explainers.py

Lines changed: 6 additions & 1 deletion
@@ -5,6 +5,7 @@
 import numpy as np
 from ..helpers.jupyter import table_to_html
 from sota_extractor2.models.linking.format import extract_value
+from sota_extractor2.helpers.optimize import optimize_filters
 
 
 class Reason:
@@ -202,4 +203,8 @@ def linking_metrics(self, experiment_name="unk"):
         del df["experiment_name"]
 
         metrics = Metrics(df, experiment_name=experiment_name)
-        return metrics
+        return metrics
+
+    def optimize_filters(self, metrics_info):
+        results = optimize_filters(self, metrics_info)
+        return results
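The new method hands the explainer itself to sota_extractor2.helpers.optimize.optimize_filters and returns its PRResults object. A minimal usage sketch, assuming an already-populated explainer (with le.proposals, paper_collection and gold_sota_records) and an illustrative metrics_info dict mapping metric names, or (task, dataset, metric) tuples, to +1 for higher-is-better and -1 for lower-is-better:

    # `explainer` and `metrics_info` are assumed to come from the surrounding pipeline
    metrics_info = {"Accuracy": 1, "Error rate": -1}      # illustrative values only
    results = explainer.optimize_filters(metrics_info)    # a PRResults instance
    results.plot()                                        # precision/recall scatter plot
    results.best(min_precision=0.5)                       # report best threshold pairs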

sota_extractor2/helpers/optimize.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
+import pandas as pd, numpy as np
+from dataclasses import dataclass, replace
+from sota_extractor2.models.linking.metrics import CM
+from matplotlib import pyplot as plt
+
+
+def annotations(matrix, structure, r, c, type='model'):
+    ann = []
+    for nc in range(0, c):
+        if type in structure[r, nc]:
+            ann.append(matrix[r, nc])
+    for nr in range(0, r):
+        if type in structure[nr, c]:
+            ann.append(matrix[nr, c])
+    return ' '.join(ann)
+
+
+def estimate_noises(extracted_values, gold_values, short_forms):
+    if not len(extracted_values):
+        return {}
+    extracted_values = set(extracted_values)
+    gold_values = set(gold_values)
+
+    return {gold: 1 - len(extracted_values & set(short_forms.get(gold, set()))) / len(extracted_values)
+            for gold in gold_values}
+
+
+def estimate_context_noise(context, records):
+    context = context or ""
+    abbrvs = context_search.extract_acronyms(context)
+    context = normalize_cell_ws(normalize_dataset(context))
+    dss = set(cs.find_datasets(context)) | set(abbrvs.keys())
+    mss = set(cs.find_metrics(context))
+    dss -= mss
+    dss = set([normalize_cell(ds) for ds in dss])
+    mss = set([normalize_cell(ms) for ms in mss])
+
+    gold_ds = set(records.dataset.values)
+    gold_ms = set(records.metric.values)
+    ds_noises = estimate_noises(dss, gold_ds, cs.datasets)
+    ms_noises = estimate_noises(mss, gold_ms, cs.metrics)
+
+    return ds_noises, ms_noises
+
+
+def estimate_paper_context_noise(paper, gold_sota_records):
+    records = gold_sota_records[gold_sota_records.paper_id == paper.paper_id]
+    datasets = de.from_paper(paper)
+    context = " ".join(datasets)
+    return estimate_context_noise(context, records)
+
+
+def estimate_caption_context_noise(paper, table, gold_sota_records):
+    table_ext_id = f"{paper.paper_id}/{table.name}/"
+    records = gold_sota_records[gold_sota_records.index.str.startswith(table_ext_id)]
+    return estimate_context_noise(table.caption, records)
+
+
+def estimate_cell_context_noise(paper, table, row, col, gold_sota_records):
+    cell_ext_id = f"{paper.paper_id}/{table.name}/{row}.{col}"
+    records = gold_sota_records[gold_sota_records.index == cell_ext_id]
+    value = annotations(table.matrix.values, table.matrix_gold_tags.values, row, col, 'dataset')
+    return estimate_context_noise(value, records)
+
+
+def average_dicts(dicts):
+    sums = {}
+    for d in dicts:
+        for k, v in d.items():
+            sums.setdefault(k, []).append(v)
+    return {k: np.mean(v) for k, v in sums.items()}
+
+
+def all_equal(row):
+    cols = ["model_type", "dataset", "metric", "task", "parsed"]
+    return np.all([row[f"{name}_pred"] == row[f"{name}_gold"] for name in cols])
+
+
+def merge_gold_records(explainer):
+    paper_ids = list(explainer.le.proposals.keys())
+
+    proposals = pd.concat(explainer.le.proposals.values())
+
+    papers = {paper_id: explainer.paper_collection.get_by_id(paper_id) for paper_id in paper_ids}
+    missing = [paper_id for paper_id, paper in papers.items() if paper is None]
+    if missing:
+        print("Missing papers in paper collection:")
+        print(", ".join(missing))
+    papers = [paper for paper in papers.values() if paper is not None]
+
+    if explainer.gold_sota_records is None:
+        print("gold_sota_records is missing")
+        return
+    else:
+        gold_sota_records = explainer.gold_sota_records
+    which = gold_sota_records.index.to_series().str.split("/", expand=True)[0] \
+        .isin([paper.paper_id for paper in papers])
+    gold_sota_records = gold_sota_records[which]
+
+    df = gold_sota_records.merge(proposals, 'outer', left_index=True, right_index=True, suffixes=['_gold', '_pred'])
+    df = df.reindex(sorted(df.columns), axis=1)
+    df.confidence = df.confidence.fillna(0.0)
+    df = df.fillna('not-present')
+    df["equal"] = df.apply(all_equal, axis=1)
+    df["pred_positive"] = df["model_type_pred"].str.contains("model-best")
+    df["gold_positive"] = df["model_type_gold"].str.contains("model-best")
+    return df
+
+
+def find_threshold_intervals(proposals, metrics_info, context="paper"):
+    # the maximal threshold at which this proposal is still returned
+    proposals["max_threshold"] = proposals.confidence
+
+    proposals["min_threshold"] = 0.0
+
+    ignore = (proposals.model_type_pred != 'model-best') | (proposals.struct_model_type == '') | \
+             (proposals.struct_dataset.str.contains('dev')) | (proposals.struct_dataset.str.contains('train'))
+
+    # these proposals won't ever be returned due to structure or model-type filters
+    proposals.loc[ignore, "min_threshold"] = 1.0
+    proposals.loc[ignore, "max_threshold"] = 0.0
+
+    all_proposals = proposals
+    proposals = proposals[~ignore]
+
+    if context == "paper":
+        context_column = proposals.index.to_series().str.split('/', expand=False).apply(lambda x: x[0])
+    else:
+        context_column = proposals.index.to_series().str.split('/', expand=False).apply(lambda x: x[0] + "/" + x[1])
+
+    for i, p in proposals.iterrows():
+        key = (p.task_pred, p.dataset_pred, p.metric_pred)
+        proposals_context = proposals[context_column == context_column[p.name]]
+        proposals_context = proposals_context[~proposals_context.parsed_pred.isna()]
+        proposals_context = proposals_context[
+            (proposals_context.task_pred == p.task_pred) &
+            (proposals_context.dataset_pred == p.dataset_pred) &
+            (proposals_context.metric_pred == p.metric_pred)
+        ]
+        d = 0
+        if key in metrics_info:
+            d = metrics_info[key]
+        elif p.metric_pred in metrics_info:
+            d = metrics_info[p.metric_pred]
+        elif 'error' in p.metric_pred.lower():
+            d = -1
+        elif 'accuracy' in p.metric_pred.lower():
+            d = 1
+
+        if d >= 0:
+            d = 1
+        else:
+            d = -1
+
+        # the minimal threshold above which all superior results are ignored
+        which = d * proposals_context.parsed_pred > d * p.parsed_pred
+        if np.any(which.values):
+            all_proposals.at[i, "min_threshold"] = proposals_context[which].confidence.values.max()
+        else:
+            which = proposals_context[proposals_context.parsed_pred == p.parsed_pred].iloc[0]
+            if which.name != p.name:
+                all_proposals.at[i, "min_threshold"] = which.confidence
+
+    return all_proposals
+
+
+def update_cm(proposal, cm, is_activated):
+    d = 1 if is_activated else -1
+    if proposal.equal and proposal.pred_positive and proposal.gold_positive:
+        cm = replace(cm, tp=cm.tp + d, fn=cm.fn - d)
+    if proposal.equal and not proposal.pred_positive and not proposal.gold_positive:
+        cm = replace(cm, tn=cm.tn + d)
+    if proposal.pred_positive and (not proposal.equal or not proposal.gold_positive):
+        cm = replace(cm, fp=cm.fp + d)
+    # if proposal.gold_positive and (not proposal.equal or not proposal.pred_positive):
+    #     cm = replace(cm, fn=cm.fn + d)
+    return cm
+
+
+def sweep_thresholds(df):
+    cm = CM(fn=sum(df.gold_positive))
+    df = df[df.min_threshold < df.max_threshold]
+
+    sweeps = df.reset_index().melt(id_vars="cell_ext_id", value_vars=["min_threshold", "max_threshold"],
+                                   var_name="threshold_type", value_name="threshold")
+
+    sweeps = sweeps.sort_values(by=["threshold", "threshold_type"]).reset_index(drop=True)
+
+    steps = sweeps.threshold.drop_duplicates().index
+
+    results = []
+    for i, idx1 in enumerate(steps[:-1]):
+        th1 = sweeps.threshold[idx1]
+
+        to_restore = cm
+        for j, idx2 in enumerate(steps[i + 1:], i + 1):
+            th2 = sweeps.threshold[idx2]
+            precision = cm.tp / (cm.tp + cm.fp + 1e-8)
+            recall = cm.tp / (cm.tp + cm.fn + 1e-8)
+            f1 = 2 * precision * recall / (precision + recall + 1e-8)
+
+            result = dict(threshold1=th1, threshold2=sweeps.threshold[idx2 - 1], tp=cm.tp, tn=cm.tn, fp=cm.fp,
+                          fn=cm.fn, precision=precision, recall=recall, f1=f1)
+            results.append(result)
+            for _, row in sweeps[sweeps.threshold == sweeps.threshold[idx2 - 1]].iterrows():
+                proposal = df.loc[row.cell_ext_id]
+                is_activated = row.threshold_type == 'min_threshold'
+                if not is_activated and proposal.min_threshold < th1:
+                    cm = update_cm(proposal, cm, is_activated)
+
+        precision = cm.tp / (cm.tp + cm.fp + 1e-8)
+        recall = cm.tp / (cm.tp + cm.fn + 1e-8)
+        f1 = 2 * precision * recall / (precision + recall + 1e-8)
+
+        result = dict(threshold1=th1, threshold2=th2, tp=cm.tp, tn=cm.tn, fp=cm.fp, fn=cm.fn,
+                      precision=precision, recall=recall, f1=f1)
+        results.append(result)
+
+        cm = to_restore
+
+        for _, row in sweeps[sweeps.threshold == th1].iterrows():
+            proposal = df.loc[row.cell_ext_id]
+
+            is_activated = row.threshold_type == 'min_threshold'
+            cm = update_cm(proposal, cm, is_activated)
+
+    return df, sweeps, steps, pd.DataFrame(results)
+
+
+class PRResults:
+    def __init__(self, results):
+        self.results = results
+
+    def plot(self):
+        plt.figure(figsize=(6, 6))
+        plt.plot(self.results["precision"], self.results["recall"], '.')
+        plt.xlabel("precision")
+        plt.ylabel("recall")
+
+    def _best(self, results, metric):
+        b = results.loc[results[metric].idxmax()]
+        x = ["precision", "recall", "f1"]
+        x.remove(metric)
+        y = [b[m] for m in x]
+        print(f"Best {metric}={b[metric]:0.2f} (with {x[0]}={y[0]:.2f} and {x[1]}={y[1]:.2f})"
+              f" is achieved with threshold1={b.threshold1} and threshold2={b.threshold2}")
+
+    def best(self, min_precision=0, min_recall=0, min_f1=0):
+        results = self.results
+        results = results[
+            (results.precision >= min_precision) &
+            (results.recall >= min_recall) &
+            (results.f1 >= min_f1)
+        ]
+        if not len(results):
+            print("No results match these criteria")
+        else:
+            self._best(results, "precision")
+            self._best(results, "recall")
+            self._best(results, "f1")
+
+
+def optimize_filters(explainer, metrics_info):
+    df = merge_gold_records(explainer)
+    df = find_threshold_intervals(df, metrics_info, context="paper")
+    df, sweeps, steps, results = sweep_thresholds(df)
+    return PRResults(results)
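Worth noting: estimate_noises scores each gold label by the fraction of extracted mentions that are not among its known short forms. A quick self-contained check of the arithmetic, with a made-up short-form map standing in for cs.datasets or cs.metrics:

    # assumes estimate_noises imported from sota_extractor2.helpers.optimize
    short_forms = {"ImageNet": {"imagenet", "ilsvrc"}}   # hypothetical short-form index
    extracted = {"imagenet", "cifar-10"}                 # mentions found in a context
    # noise = 1 - |extracted ∩ short forms| / |extracted| = 1 - 1/2 = 0.5
    print(estimate_noises(extracted, {"ImageNet"}, short_forms))  # {'ImageNet': 0.5}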

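find_threshold_intervals gives each proposal an interval (min_threshold, max_threshold] of confidence thresholds at which it would be emitted: above its own confidence it is cut by the confidence filter, and at or below the confidence of a superior competing proposal in the same context it is shadowed. sweep_thresholds then walks both thresholds over these intervals and records a confusion matrix per (threshold1, threshold2) pair, so PRResults only needs the resulting frame's threshold and precision/recall/f1 columns. A sketch on invented numbers:

    # assumes PRResults imported from sota_extractor2.helpers.optimize
    import pandas as pd
    fake = pd.DataFrame([
        dict(threshold1=0.2, threshold2=0.5, tp=8, tn=3, fp=2, fn=2, precision=0.80, recall=0.80, f1=0.80),
        dict(threshold1=0.5, threshold2=0.9, tp=6, tn=5, fp=0, fn=4, precision=1.00, recall=0.60, f1=0.75),
    ])
    PRResults(fake).best(min_precision=0.9)   # keeps only the second row and reports it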
sota_extractor2/models/linking/metrics.py

Lines changed: 4 additions & 4 deletions
@@ -8,10 +8,10 @@
 
 @dataclass
 class CM:
-    tp: float
-    fn: float
-    fp: float
-    tn: float
+    tp: float = 0
+    fn: float = 0
+    fp: float = 0
+    tn: float = 0
 
 class Metrics:
     def __init__(self, df, experiment_name="unk"):
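The new field defaults matter because optimize.py builds partial confusion matrices: sweep_thresholds starts from CM(fn=...) alone, and update_cm adjusts counts with dataclasses.replace. A small sketch of that pattern:

    # assumes CM imported from sota_extractor2.models.linking.metrics
    from dataclasses import replace
    cm = CM(fn=10)                                  # tp, fp, tn all default to 0 now
    cm = replace(cm, tp=cm.tp + 1, fn=cm.fn - 1)    # one proposal becomes a true positive
    recall = cm.tp / (cm.tp + cm.fn + 1e-8)         # 1 / 10, as in sweep_thresholds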
