Skip to content

Commit 40df5f0

Browse files
committed
Add missing matcher. Match with all instead of best
1 parent a0aea55 commit 40df5f0

File tree

1 file changed

+56
-39
lines changed

1 file changed

+56
-39
lines changed

label_tables.py

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import sys
1010
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
11-
from collections import Counter
11+
from collections import Counter, namedtuple
1212

1313

1414
arxiv_url_re = re.compile(r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
@@ -142,50 +142,67 @@ def match_metric(metric, tables, value):
142142
return matching_tables
143143

144144

145-
comparators = [
146-
test_near,
147-
lambda metric, target: test_near(metric.shift(2), target),
148-
lambda metric, target: test_near(metric, target.shift(2)),
149-
lambda metric, target: test_near(Decimal("1") - metric, target),
150-
lambda metric, target: test_near(Decimal("100") - metric.shift(2), target),
151-
lambda metric, target: test_near(Decimal("100") - metric, target.shift(2))
152-
]
145+
comparators = {
146+
"a=b": test_near,
147+
"100a=b": lambda metric, target: test_near(metric.shift(2), target),
148+
"a=100b": lambda metric, target: test_near(metric, target.shift(2)),
149+
"1-a=b": lambda metric, target: test_near(Decimal("1") - metric, target),
150+
"100-a=b": lambda metric, target: test_near(Decimal("100") - metric, target),
151+
"100-100a=b": lambda metric, target: test_near(Decimal("100") - metric.shift(2), target),
152+
"100-a=100b": lambda metric, target: test_near(Decimal("100") - metric, target.shift(2))
153+
}
153154

154155

155156
def empty_celltags_like(table):
156-
return = pd.DataFrame().reindex_like(table).fillna('')
157+
return pd.DataFrame().reindex_like(table).fillna('')
158+
159+
160+
def mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name):
161+
comparator = comparators[comp_name]
162+
rows, cols = table.shape
163+
hits = 0
164+
cell_tags = empty_celltags_like(table)
165+
for col in range(cols):
166+
for row in range(rows):
167+
for val in table.iloc[row, col]:
168+
for record in values:
169+
if comparator(record.normalized, val):
170+
hits += 1
171+
tags = f"<hit><sota>{record.value}</sota>" +\
172+
f"<paper>{record.arxiv_id}</paper>" +\
173+
f"<model>{record.model}</model>" +\
174+
f"<metric>{metric_name}</metric>" +\
175+
f"<dataset>{dataset_name}</dataset>" +\
176+
f"<task>{task_name}</task>"
177+
if arxiv_id == record.arxiv_id:
178+
tags += "<this_paper/>"
179+
tags += f"<comparator>{comp_name}</comparator>" +\
180+
f"<matched_cell>{val}</matched_cell></hit>"
181+
cell_tags.iloc[row, col] += tags
182+
return cell_tags, hits
157183

158184

159185
def mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values):
160186
max_hits = 0
161187
best_tags = None
162-
rows, cols = table.shape
163188

164-
for comparator in comparators:
165-
hits = 0
166-
cell_tags = empty_celltags_like(table)
167-
for col in range(cols):
168-
for row in range(rows):
169-
for val in table.iloc[row, col]:
170-
for record in values:
171-
if comparator(record["normalized"], val):
172-
hits += 1
173-
tags = f"<sota>{record['value']}</sota>" +\
174-
f"<paper>{record['arxiv_id']}</paper>" +\
175-
f"<model>{record['model']}</model>" +\
176-
f"<metric>{metric_name}</metric>" +\
177-
f"<dataset>{dataset_name}</dataset>" +\
178-
f"<task>{task_name}</task>"
179-
if arxiv_id == record["arxiv_id"]:
180-
tags += "<this_paper>"
181-
cell_tags.iloc[row, col] += tags
189+
for comp_name in comparators:
190+
cell_tags, hits = mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name)
182191
if max_hits < hits:
183192
max_hits = hits
184193
best_tags = cell_tags
185194

186195
return best_tags
187196

188197

198+
def mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, table, values):
199+
all_tags = empty_celltags_like(table)
200+
for comp_name in comparators:
201+
cell_tags, _ = mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name)
202+
all_tags += cell_tags
203+
204+
return all_tags
205+
189206
def normalize_string(s):
190207
return s.lower.strip()
191208

@@ -211,14 +228,13 @@ def mark_strings(table, tags, values):
211228
def match_many(output_dir, task_name, dataset_name, metric_name, tables, values):
212229
for arxiv_id in tables:
213230
for table in tables[arxiv_id]:
214-
best = mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
231+
tags = mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
215232
global metatables
216-
if best is not None:
217-
key = (arxiv_id, table)
218-
if key in metatables:
219-
metatables[key] += best
220-
else:
221-
metatables[key] = best
233+
key = (arxiv_id, table)
234+
if key in metatables:
235+
metatables[key] += tags
236+
else:
237+
metatables[key] = tags
222238

223239

224240
def normalize_metric(value):
@@ -252,6 +268,7 @@ def normalize_table(table):
252268
# mark table with a given dataset_name and metric_name
253269
# mark hit cells with sota-tag, model_name and paper_id
254270
# if table.arxiv_id == paper_id: mark with this-tag
271+
PaperResult = namedtuple("PaperResult", ["arxiv_id", "model", "value", "normalized"])
255272

256273

257274
def label_tables(tasksfile, tables_dir, output, output_dir):
@@ -270,15 +287,15 @@ def label_tables(tasksfile, tables_dir, output, output_dir):
270287
if match is not None:
271288
arxiv_id = match.group("arxiv_id")
272289
for metric in row.metrics:
273-
arxivs_by_metrics.setdefault((task.name, dataset.name, metric), []).append(
274-
dict(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
290+
arxivs_by_metrics.setdefault((task.name, dataset.name, metric), set()).add(
291+
PaperResult(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
275292
normalized=normalize_metric(row.metrics[metric])
276293
)
277294
)
278295

279296
for task, dataset, metric in arxivs_by_metrics:
280297
records = arxivs_by_metrics[(task, dataset, metric)]
281-
tabs = {r["arxiv_id"]: tables[r["arxiv_id"]] for r in records if r["arxiv_id"] in tables}
298+
tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
282299
match_many(output_dir, task, dataset, metric, tabs, records)
283300

284301
global metatables

0 commit comments

Comments
 (0)