
Commit fdb06d0

Author: Marcin Kardas (committed)
Better support for complementary metrics
1 parent 7c22a57 commit fdb06d0

6 files changed: +175 -11 lines changed

sota_extractor2/helpers/latex_converter.py

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ def latex2html(self, source_dir, output_dir, use_named_volumes=False):
         except ContainerError as err:
             if err.exit_status == MAGIC_EXIT_ERROR:
                 raise LatexConversionError("LaTeXML was unable to convert source code of this paper")
+            if "Unable to find any suitable tex file" in err.stderr.decode('utf-8'):
+                raise LatexConversionError("Unable to find any suitable tex file")
             raise
 
         # todo: check for errors
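A small standalone sketch of the new branch's effect: the ContainerError stand-in and the MAGIC_EXIT_ERROR value below are placeholders invented for illustration; only the exception type, the messages, and the stderr substring check come from the change above.

# Standalone sketch of the except-branch above. FakeContainerError and the
# MAGIC_EXIT_ERROR value are placeholders; the messages and the substring
# check mirror the diff.
class LatexConversionError(Exception):
    pass

class FakeContainerError(Exception):
    def __init__(self, exit_status, stderr):
        self.exit_status = exit_status
        self.stderr = stderr

MAGIC_EXIT_ERROR = 117  # placeholder value, not the module's real constant

def handle_container_error(err):
    if err.exit_status == MAGIC_EXIT_ERROR:
        raise LatexConversionError("LaTeXML was unable to convert source code of this paper")
    if "Unable to find any suitable tex file" in err.stderr.decode('utf-8'):
        raise LatexConversionError("Unable to find any suitable tex file")
    raise err

# A source archive without any .tex file now surfaces as a LatexConversionError
# instead of an unhandled container error:
try:
    handle_container_error(FakeContainerError(1, b"Unable to find any suitable tex file"))
except LatexConversionError as e:
    print(e)  # -> Unable to find any suitable tex file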

sota_extractor2/models/linking/context_search.py

Lines changed: 4 additions & 1 deletion

@@ -80,7 +80,7 @@ def find_metrics(self, text):
     def find_tasks(self, text):
         return EvidenceFinder.find_names(text, self.all_tasks_trie)
 
-    def _init_structs(self, taxonomy):
+    def init_evidence_dicts(self, taxonomy):
         self.tasks, self.datasets, self.metrics = EvidenceFinder.get_basic_dicts(taxonomy)
         EvidenceFinder.merge_evidences(self.tasks, manual_dicts.tasks)
         EvidenceFinder.merge_evidences(self.datasets, manual_dicts.datasets)
@@ -92,6 +92,9 @@ def _init_structs(self, taxonomy):
             'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
         })
 
+    def _init_structs(self, taxonomy):
+        self.init_evidence_dicts(taxonomy)
+
         self.datasets = {k: set(v) for k, v in self.datasets.items()}
         self.metrics = {k: set(v) for k, v in self.metrics.items()}
         self.tasks = {k: set(v) for k, v in self.tasks.items()}
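Splitting `_init_structs` means the evidence dictionaries can now be extended before they are frozen into sets. A hedged sketch of what that enables, assuming these methods live on the `EvidenceFinder` class (the diff only shows the method bodies); the subclass name and the extra evidence entries are made up:

# Hypothetical subclass: add extra surface forms before _init_structs turns
# the dictionaries into sets. The import path is an assumption based on the
# file path above; the evidence entries are illustrative only.
from sota_extractor2.models.linking.context_search import EvidenceFinder

class CustomEvidenceFinder(EvidenceFinder):
    def init_evidence_dicts(self, taxonomy):
        super().init_evidence_dicts(taxonomy)
        EvidenceFinder.merge_evidences(self.datasets, {
            'LibriSpeech test-other': ['libri speech test other', 'test other', 'noisy'],
        })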

sota_extractor2/models/linking/manual_dicts.py

Lines changed: 24 additions & 3 deletions

@@ -163,15 +163,36 @@
 
 tasks = {}
 
-complementary_metrics = {
+complementary_metrics = {k.lower(): v for k, v in {
     'Accuracy': 'Error',
     'Error': 'Accuracy',
+    'Acc': 'Err',
+    'Err': 'Acc',
     'Percentage Error': 'Accuracy',
+    'Error rate': 'Accuracy',
     'Word Error Rate': 'Word Accuracy',
     'Word Error Rate (WER)': 'Word Accuracy',
     'Top-1 Accuracy': 'Top-1 Error Rate',
-    'Top-5 Accuracy': 'Top-5 Error',
-}
+    'Top-3 Accuracy': 'Top-3 Error Rate',
+    'Top-5 Accuracy': 'Top-5 Error Rate',
+    'Top 1 Accuracy': 'Top 1 Error Rate',
+    'Top 3 Accuracy': 'Top 3 Error Rate',
+    'Top 5 Accuracy': 'Top 5 Error Rate',
+    'Top-1 Error Rate': 'Top-1 Accuracy',
+    'Top-3 Error Rate': 'Top-3 Accuracy',
+    'Top-5 Error Rate': 'Top-5 Accuracy',
+    'Top 1 Error Rate': 'Top 1 Accuracy',
+    'Top 3 Error Rate': 'Top 3 Accuracy',
+    'Top 5 Error Rate': 'Top 5 Accuracy',
+    'Top-1 Error': 'Top-1 Accuracy',
+    'Top-3 Error': 'Top-3 Accuracy',
+    'Top-5 Error': 'Top-5 Accuracy',
+    'Top 1 Error': 'Top 1 Accuracy',
+    'Top 3 Error': 'Top 3 Accuracy',
+    'Top 5 Error': 'Top 5 Accuracy',
+    'Classification Accuracy': 'Classification Error',
+    'Classification Error': 'Classification Accuracy',
+}.items()}
 
 stop_words = {
     "a", "an", "and", "are", "as", "at", "be", "but", "by",

sota_extractor2/models/linking/taxonomy.py

Lines changed: 2 additions & 2 deletions

@@ -26,10 +26,10 @@ def _get_complementary_metrics(self):
         self._complementary = {}
         for record in self.canonical_records:
             metric = record["metric"]
-            if metric in complementary_metrics:
+            if metric.lower() in complementary_metrics:
                 task = record["task"]
                 dataset = record["dataset"]
-                comp_metric = complementary_metrics[record["metric"]]
+                comp_metric = complementary_metrics[metric.lower()]
                 complementary.append(
                     dict(
                         task=task,
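With lowercased keys on one side and `metric.lower()` on the other, the matching becomes case-insensitive end to end. A toy reproduction of the loop body on a made-up canonical record; only the accessed fields ("task", "dataset", "metric") and the import path come from the commit:

# Toy reproduction of the matching logic on a single, made-up record.
from sota_extractor2.models.linking.manual_dicts import complementary_metrics

record = {"task": "Image Classification", "dataset": "ImageNet", "metric": "Top-5 Error"}

complementary = []
metric = record["metric"]
if metric.lower() in complementary_metrics:  # 'top-5 error' matches despite the capitalization
    comp_metric = complementary_metrics[metric.lower()]
    complementary.append(dict(task=record["task"], dataset=record["dataset"], metric=comp_metric))

print(complementary)  # [{'task': 'Image Classification', 'dataset': 'ImageNet', 'metric': 'Top-5 Accuracy'}]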

sota_extractor2/models/structure/experiment.py

Lines changed: 9 additions & 5 deletions

@@ -297,6 +297,14 @@ def show_results(self, *ds, normalize=True, full_cm=True):
         suffix = '_full' if full_cm and f'{prefix}_cm_full' in self.results else ''
         self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm{suffix}']), normalize=normalize)
 
+    def get_cm_labels(self, cm):
+        if len(cm) == 6:
+            target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        else:
+            target_names = ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
+                            "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        return target_names
+
     def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if normalize:
             s = cm.sum(axis=1)[:, None]
@@ -305,11 +313,7 @@ def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if fmt is None:
             fmt = "0.2f" if normalize else "d"
 
-        if len(cm) == 6:
-            target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
-        else:
-            target_names = ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
-                            "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        target_names = self.get_cm_labels(cm)
         df_cm = pd.DataFrame(cm, index=[i for i in target_names],
                              columns=[i for i in target_names])
         plt.figure(figsize=(10, 10))
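Extracting `get_cm_labels` is what lets `ULMFiTTableTypeExperiment` below plot its own three-class confusion matrix through the same `_plot_confusion_matrix`. A minimal sketch of the hook in isolation; the stand-in class is not the real Experiment, only the label lists and the size check come from the diff:

import numpy as np

# Stand-in class reproducing just the new hook; the label set is chosen
# purely from the confusion matrix size.
class LabelsOnly:
    def get_cm_labels(self, cm):
        if len(cm) == 6:
            return ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
        return ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
                "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]

print(LabelsOnly().get_cm_labels(np.zeros((6, 6))))         # coarse 6-label set
print(len(LabelsOnly().get_cm_labels(np.zeros((11, 11)))))  # 11, the fine-grained set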

sota_extractor2/models/structure/ulmfit_experiment.py

Lines changed: 134 additions & 0 deletions

@@ -1,4 +1,6 @@
 from .experiment import Experiment, label_map_ext
+from sota_extractor2.models.structure.nbsvm import *
+from sklearn.metrics import confusion_matrix
 from .nbsvm import preds_for_cell_content, preds_for_cell_content_max, preds_for_cell_content_multi
 import dataclasses
 from dataclasses import dataclass
@@ -134,3 +136,135 @@ def evaluate(self, model, train_df, valid_df, test_df):
         true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
         self._set_results(prefix, preds, true_y, true_y_ext)
         self._preds.append(probs)
+
+
+@dataclass
+class ULMFiTTableTypeExperiment(ULMFiTExperiment):
+    sigmoid: bool = True
+    distinguish_ablation: bool = True
+    irrelevant_as_class: bool = False
+    caption: bool = True
+    first_row: bool = False
+    first_column: bool = False
+    referencing_sections: bool = False
+    dedup_seqs: bool = False
+
+    def _save_model(self, path):
+        pass
+
+    def _transform_df(self, df):
+        df = df.copy(True)
+        if self.sigmoid:
+            if self.irrelevant_as_class:
+                df["irrelevant"] = ~(df["sota"] | df["ablation"])
+            if not self.distinguish_ablation:
+                df["sota"] = df["sota"] | df["ablation"]
+                df = df.drop(columns=["ablation"])
+        else:
+            if self.distinguish_ablation:
+                df["class"] = 2
+                df.loc[df.ablation, "class"] = 1
+                df.loc[df.sota, "class"] = 0
+            else:
+                df["class"] = 1
+                df.loc[df.sota, "class"] = 0
+                df.loc[df.ablation, "class"] = 0
+
+        df["label"] = 2
+        df.loc[df.ablation, "label"] = 1
+        df.loc[df.sota, "label"] = 0
+        drop_columns = []
+        if not self.caption:
+            drop_columns.append("caption")
+        if not self.first_column:
+            drop_columns.append("col0")
+        if not self.first_row:
+            drop_columns.append("row0")
+        if not self.referencing_sections:
+            drop_columns.append("sections")
+        df = df.drop(columns=drop_columns)
+        return df
+
+    def evaluate(self, model, train_df, valid_df, test_df):
+        valid_probs = model.get_preds(ds_type=DatasetType.Valid, ordered=True)[0].cpu().numpy()
+        test_probs = model.get_preds(ds_type=DatasetType.Test, ordered=True)[0].cpu().numpy()
+        train_probs = model.get_preds(ds_type=DatasetType.Train, ordered=True)[0].cpu().numpy()
+        self._preds = []
+
+        def multipreds2preds(preds, threshold=0.5):
+            bs = preds.shape[0]
+            return np.concatenate([probs, np.ones((bs, 1)) * threshold], axis=-1).argmax(-1)
+
+        for prefix, tdf, probs in zip(["train", "valid", "test"],
+                                      [train_df, valid_df, test_df],
+                                      [train_probs, valid_probs, test_probs]):
+
+            if self.sigmoid and not self.irrelevant_as_class:
+                preds = multipreds2preds(probs)
+            else:
+                preds = np.argmax(probs, axis=1)
+                if not self.distinguish_ablation:
+                    preds *= 2
+
+            true_y = tdf["label"]
+            self._set_results(prefix, preds, true_y)
+            self._preds.append(probs)
+
+    def _set_results(self, prefix, preds, true_y, true_y_ext=None):
+        def metrics(preds, true_y):
+            y = true_y
+            p = preds
+
+            if self.distinguish_ablation:
+                g = {0: 0, 1: 0, 2: 1}.get
+                bin_y = np.array([g(x) for x in y])
+                bin_p = np.array([g(x) for x in p])
+                irr = 2
+            else:
+                bin_y = y
+                bin_p = p
+                irr = 1
+
+            acc = (p == y).mean()
+            tp = ((y != irr) & (p == y)).sum()
+            fp = ((p != irr) & (p != y)).sum()
+            fn = ((y != irr) & (p == irr)).sum()
+
+            bin_acc = (bin_p == bin_y).mean()
+            bin_tp = ((bin_y != 1) & (bin_p == bin_y)).sum()
+            bin_fp = ((bin_p != 1) & (bin_p != bin_y)).sum()
+            bin_fn = ((bin_y != 1) & (bin_p == 1)).sum()
+
+            prec = tp / (fp + tp)
+            reca = tp / (fn + tp)
+            bin_prec = bin_tp / (bin_fp + bin_tp)
+            bin_reca = bin_tp / (bin_fn + bin_tp)
+            return {
+                "precision": prec,
+                "accuracy": acc,
+                "recall": reca,
+                "TP": tp,
+                "FP": fp,
+                "bin_precision": bin_prec,
+                "bin_accuracy": bin_acc,
+                "bin_recall": bin_reca,
+                "bin_TP": bin_tp,
+                "bin_FP": bin_fp,
+            }
+
+        m = metrics(preds, true_y)
+        r = {}
+        r[f"{prefix}_accuracy"] = m["accuracy"]
+        r[f"{prefix}_precision"] = m["precision"]
+        r[f"{prefix}_recall"] = m["recall"]
+        r[f"{prefix}_bin_accuracy"] = m["bin_accuracy"]
+        r[f"{prefix}_bin_precision"] = m["bin_precision"]
+        r[f"{prefix}_bin_recall"] = m["bin_recall"]
+        r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
+        self.update_results(**r)
+
+    def get_cm_labels(self, cm):
+        if len(cm) == 3:
+            return ["SOTA", "ABLATION", "IRRELEVANT"]
+        else:
+            return ["SOTA", "IRRELEVANT"]
