
Commit 1378aa2

Author: Marcin Kardas
Add support for complementary metrics
1 parent 53917e0 commit 1378aa2

5 files changed, +66 -20 lines changed


sota_extractor2/models/linking/bm25_naive.py

Lines changed: 8 additions & 9 deletions
@@ -230,33 +230,32 @@ def linked_proposals(proposals):
 
             df = taxonomy_linking(prop.dataset, datasets, desc, topk=topk, debug_info=prop)
             for _, row in df.iterrows():
+                raw_value = prop.raw_value
+                parsed = float(extract_value(raw_value, format))
                 metric = row['metric']
-                # if ("error" in metric or "Error" in metric) and (first_num > 0.5):
-                if (metric.strip().lower() == "error") and (first_num > 0.5):
-                    metric = "Accuracy"
+                if metric != row['true_metric']:
+                    metric = row['true_metric']
+                    parsed = 1 - parsed if parsed < 1 else 100 - parsed
 
                 linked = {
                     'dataset': row['dataset'],
                     'metric': metric,
                     'task': row['task'],
                     'format': format,
-                    'raw_value': prop.raw_value,
+                    'raw_value': raw_value,
                     'model': prop.model_name,
                     'model_type': prop.model_type,
                     'cell_ext_id': prop.cell.cell_ext_id,
                     'confidence': row['confidence'],
                     'struct_model_type': prop.model_type,
-                    'struct_dataset': prop.dataset
+                    'struct_dataset': prop.dataset,
+                    'parsed': parsed
                 }
                 yield linked
 
     # specify columns in case there's no proposal
     proposals = pd.DataFrame.from_records(list(linked_proposals(proposals)), columns=proposal_columns)
-
-    if len(proposals):
-        proposals["parsed"]=proposals[["raw_value", "format"]].apply(
-            lambda row: float(extract_value(row.raw_value, row.format)), axis=1)
     return proposals
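Note: the conversion above treats a parsed value below 1 as a fraction and anything else as a percentage. A minimal sketch of the rule, with hypothetical values:

# Sketch of the complementary-metric conversion used in linked_proposals (hypothetical values).
def complement(parsed: float) -> float:
    return 1 - parsed if parsed < 1 else 100 - parsed

print(complement(0.057))  # error given as a fraction   -> approx. 0.943
print(complement(5.7))    # error given as a percentage -> 94.3

So a cell reporting 5.7 for an error metric would be stored as 94.3 once the metric is normalized to its accuracy counterpart.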

sota_extractor2/models/linking/context_search.py

Lines changed: 7 additions & 3 deletions
@@ -104,7 +104,7 @@ def _init_structs(self, taxonomy):
         self.all_tasks_trie = EvidenceFinder.make_trie(self.all_tasks)
 
 
-@njit(inline="always")
+@njit
 def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
     logprob = 0.0
     empty = typed.Dict.empty(types.unicode_type, types.float64)
@@ -114,6 +114,7 @@ def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb):
     return logprob
 
 
+# compute log-probabilities in a given context and add them to logprobs
 @njit
 def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
                      dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
@@ -128,7 +129,7 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task
         if task not in task_cache:
             task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb)
 
-        logprobs[i] = dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
+        logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
 
 
 class ContextSearch:
@@ -262,7 +263,10 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):
             print("[EA] No gold sota record found for the cell")
         # end of error analysis only
         pipeline_logger("linking::taxonomy_linking::topk", ext_id=cellstr, topk=p.head(5))
-        return p.head(topk)
+
+        q = p.head(topk).copy()
+        q["true_metric"] = q.apply(lambda row: self.taxonomy.normalize_metric(row.task, row.dataset, row.metric), axis=1)
+        return q
 
 
 # todo: compare regex approach (old) with find_datasets(.) (current)
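Note: ContextSearch.__call__ now returns its top-k proposals with an extra true_metric column, which bm25_naive.py uses to detect complementary metrics. A rough sketch of the added column, using a hypothetical stand-in for the taxonomy:

# Hypothetical sketch of the true_metric column added to the top-k frame.
import pandas as pd

class FakeTaxonomy:
    # stand-in for self.taxonomy; maps a complementary entry back to the taxonomy metric
    def normalize_metric(self, task, dataset, metric):
        mapping = {("Image Classification", "ImageNet", "Top-1 Error Rate"): "Top-1 Accuracy"}
        return mapping.get((task, dataset, metric), metric)

taxonomy = FakeTaxonomy()
p = pd.DataFrame([{"task": "Image Classification", "dataset": "ImageNet",
                   "metric": "Top-1 Error Rate", "confidence": 0.9}])
q = p.head(1).copy()
q["true_metric"] = q.apply(lambda row: taxonomy.normalize_metric(row.task, row.dataset, row.metric), axis=1)
print(q[["metric", "true_metric"]])  # metric: Top-1 Error Rate, true_metric: Top-1 Accuracy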

sota_extractor2/models/linking/format.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ def format_to_regexp(format):
     return re.compile('^' + regexp), fn
 
 def extract_value(cell_value, format):
-    cell_value = re.sub(r"\s+%", "%", cell_value).replace(",", "").strip()
+    cell_value = re.sub(r"\s+%", "%", cell_value).replace(",", "")
+    cell_value = cell_value.replace("(", " ").replace(")", " ").strip()
     regexp, fn = format_to_regexp(format)
     match = regexp.match(cell_value)
     if match is None or not len(match.groups()):
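Note: extract_value now blanks out parentheses before matching, so cell values with bracketed extras still parse. A small sketch of the preprocessing only (hypothetical inputs; the format-specific regexp matching is unchanged):

import re

def preprocess(cell_value):
    # same normalization as in extract_value: glue '%' to the number, drop thousands separators, blank out parentheses
    cell_value = re.sub(r"\s+%", "%", cell_value).replace(",", "")
    return cell_value.replace("(", " ").replace(")", " ").strip()

print(preprocess("95.1 (single model)"))  # -> '95.1  single model'
print(preprocess("1,234.5 %"))            # -> '1234.5%'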

sota_extractor2/models/linking/manual_dicts.py

Lines changed: 19 additions & 3 deletions
@@ -1,8 +1,9 @@
 metrics = {
+    'Accuracy': ['acc', 'accuracy'],
     'BLEU': ['bleu'],
     'BLEU score': ['bleu'],
     'Character Error Rate': ['cer', 'cers'],
-    'Error': ['error'],
+    'Error': ['error', 'err', 'error rate'],
     'Exact Match Ratio': ['exact match'],
     'F1': ['f1', 'f1 score'],
     'F1 score': ['f1', 'f1 score'],
@@ -11,6 +12,7 @@
         'phoneme error rate', 'error', 'error rate', 'error rates'],
     'Word Error Rate': ['wer', 'wers', 'word error rate', 'word error rates', 'error', 'error rate', 'error rates'],
     'Word Error Rate (WER)': ['wer', 'wers', 'word error rate', 'word error rates', 'error', 'error rate', 'error rates'],
+    'Word Accuracy': ['accuracy', 'word accuracy', 'acc', 'word acc'],
     'ROUGE-1': ['r1'],
     'ROUGE-2': ['r2'],
     'ROUGE-F': ['rf'],
@@ -31,8 +33,12 @@
     'Category IoU': ['cat iou', 'iou cat'],
     'class iIoU': ['class iiou', 'iiou cla'],
     'Category iIoU': ['cat iiou', 'iiou cat'],
-    'Mean Accuracy': ['mean acc', 'mean', 'acc']
-
+    'Mean Accuracy': ['mean acc', 'mean', 'acc', 'accuracy', 'mean accuracy'],
+    'Mean Error': ['mean err', 'mean', 'err', 'mean error', 'error'],
+    'Top-1 Accuracy': ['top 1 accuracy', 'top 1', 'top 1 acc'],
+    'Top-5 Accuracy': ['top 5 accuracy', 'top 5', 'top 5 acc'],
+    'Top-1 Error Rate': ['top 1 error', 'top 1', 'top 1 err'],
+    'Top-5 Error': ['top 5 error', 'top 5', 'top 5 err']
 }
 
 # datasets[taxonomy name] is a list of normalized evidences for taxonomy name
@@ -153,3 +159,13 @@
 }
 
 tasks = {}
+
+complementary_metrics = {
+    'Accuracy': 'Error',
+    'Error': 'Accuracy',
+    'Percentage Error': 'Accuracy',
+    'Word Error Rate': 'Word Accuracy',
+    'Word Error Rate (WER)': 'Word Accuracy',
+    'Top-1 Accuracy': 'Top-1 Error Rate',
+    'Top-5 Accuracy': 'Top-5 Error',
+}

sota_extractor2/models/linking/taxonomy.py

Lines changed: 30 additions & 4 deletions
@@ -1,23 +1,49 @@
 from pathlib import Path
 import json
 from collections import OrderedDict
-
+from sota_extractor2.models.linking.manual_dicts import complementary_metrics
 
 
 class Taxonomy:
     def __init__(self, taxonomy, metrics_info):
-        self.taxonomy = self._read_taxonomy(taxonomy)
+        self.taxonomy = self._get_taxonomy(taxonomy)
         self.metrics_info = self._read_metrics_info(metrics_info)
         self.tasks = self._get_axis('task')
         self.datasets = self._get_axis('dataset')
         self.metrics = self._get_axis('metric')
 
+    def normalize_metric(self, task, dataset, metric):
+        if (task, dataset, metric) in self._complementary:
+            return self._complementary[(task, dataset, metric)][2]
+        return metric
+
     def _read_json(self, path):
         with open(path, "rt") as f:
             return json.load(f)
 
-    def _read_taxonomy(self, path):
-        self._records = self._read_json(path)
+    def _get_complementary_metrics(self, records):
+        complementary = []
+        self._complementary = {}
+        for record in records:
+            metric = record["metric"]
+            if metric in complementary_metrics:
+                task = record["task"]
+                dataset = record["dataset"]
+                comp_metric = complementary_metrics[record["metric"]]
+                complementary.append(
+                    dict(
+                        task=task,
+                        dataset=dataset,
+                        metric=comp_metric
+                    )
+                )
+
+                self._complementary[(task, dataset, comp_metric)] = (task, dataset, metric)
+        return complementary
+
+    def _get_taxonomy(self, path):
+        records = self._read_json(path)
+        self._records = records + self._get_complementary_metrics(records)
         return [(r["task"], r["dataset"], r["metric"]) for r in self._records]
 
     def _get_axis(self, axis):
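Note: the taxonomy now also contains one synthetic record per metric listed in complementary_metrics, and _complementary maps each synthetic (task, dataset, metric) triple back to the original record; normalize_metric returns that original metric. A small sketch of the expansion with a single hypothetical record:

# Hypothetical sketch of _get_complementary_metrics / normalize_metric.
complementary_metrics = {'Top-1 Accuracy': 'Top-1 Error Rate'}
records = [{"task": "Image Classification", "dataset": "ImageNet", "metric": "Top-1 Accuracy"}]

complementary, back_map = [], {}
for r in records:
    if r["metric"] in complementary_metrics:
        comp = complementary_metrics[r["metric"]]
        complementary.append(dict(task=r["task"], dataset=r["dataset"], metric=comp))
        back_map[(r["task"], r["dataset"], comp)] = (r["task"], r["dataset"], r["metric"])

# the synthetic "Top-1 Error Rate" entry resolves back to "Top-1 Accuracy"
print(back_map[("Image Classification", "ImageNet", "Top-1 Error Rate")][2])  # Top-1 Accuracy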
