Skip to content

Commit 43ea412

Browse files
committed
Mark cells with sota tags
1 parent 8e40f84 commit 43ea412

File tree

3 files changed

+113
-14
lines changed

3 files changed

+113
-14
lines changed

extract-tables.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
import re
1111
from ast import literal_eval
1212

13-
14-
class Tabular:
15-
def __init__(self, data, caption):
16-
self.data = data
17-
self.caption = caption
13+
from tabular import Tabular
1814

1915

2016
def flatten_tables(soup):

label_tables.py

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import sys
1010
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
11+
from collections import Counter
1112

1213

1314
arxiv_url_re = re.compile(r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
def get_table(filename):
    """Load one extracted table CSV as an all-string DataFrame.

    Every cell is read as ``str`` (no numeric coercion) and missing cells
    become empty strings.  A zero-byte/contentless CSV yields an empty
    DataFrame instead of raising.
    """
    try:
        table = pd.read_csv(filename, header=None, dtype=str)
    except pd.errors.EmptyDataError:
        # Some extracted tables are empty files; treat them as empty tables.
        return pd.DataFrame()
    return table.fillna('')
3334

3435

3536
def get_tables(tables_dir):
@@ -137,6 +138,72 @@ def match_metric(metric, tables, value):
137138
return matching_tables
138139

139140

141+
# Candidate ways a normalized sota value may relate to a value found in a
# table cell.  Each entry is callable as comparator(metric, target) -> bool,
# delegating to the sibling test_near() for the approximate comparison.
# Decimal.shift(2) shifts the coefficient two digits left, i.e. scales by 100,
# so these cover fraction-vs-percent mismatches (0.95 vs 95) and
# error-vs-accuracy complements (1 - x, 100 - x).
comparators = [
    test_near,                                                                # exact-ish match
    lambda metric, target: test_near(metric.shift(2), target),                # metric as percent
    lambda metric, target: test_near(metric, target.shift(2)),                # cell as percent
    lambda metric, target: test_near(metric, Decimal("1") - target),          # complement (fraction)
    lambda metric, target: test_near(metric.shift(2), Decimal("100") - target),  # percent complement
    lambda metric, target: test_near(metric, (Decimal("1") - target).shift(2))   # complement, then percent
]
149+
150+
151+
def mark_with_best_comparator(metric_name, arxiv_id, table, values):
    """Tag table cells that match known sota values, using the best comparator.

    Tries every transformation in the module-level ``comparators`` list and
    keeps the tag layout produced by the one with the most cell/record hits.

    Parameters:
        metric_name: currently unused here — NOTE(review): kept for interface
            symmetry with callers; confirm whether it should filter ``values``.
        arxiv_id: id of the paper this table was extracted from.
        table: DataFrame whose cells are lists of Decimal values
            (output of normalize_table).
        values: list of dicts with keys "normalized", "value", "arxiv_id",
            "model" describing known sota records.

    Returns a DataFrame of concatenated tag strings (same shape as ``table``)
    when the best comparator produced more than 2 hits, else None.
    """
    max_hits = 0
    best_tags = None
    rows, cols = table.shape

    for comparator in comparators:
        hits = 0
        # Empty-string grid with the same index/columns as the table.
        cell_tags = pd.DataFrame().reindex_like(table).fillna('')
        for col in range(cols):
            for row in range(rows):
                # Each cell holds zero or more Decimal values.
                for val in table.iloc[row, col]:
                    for record in values:
                        if comparator(record["normalized"], val):
                            hits += 1
                            tags = f"<sota>{record['value']}</sota>" +\
                                f"<paper>{record['arxiv_id']}</paper>" +\
                                f"<model>{record['model']}</model>"
                            if arxiv_id == record["arxiv_id"]:
                                # NOTE(review): deliberately unclosed marker
                                # tag, unlike the others — confirm downstream
                                # consumers expect "<this_paper>" without a
                                # closing tag.
                                tags += "<this_paper>"
                            # A cell may match several records; tags accumulate.
                            cell_tags.iloc[row, col] += tags
        # Strict '<' keeps the FIRST comparator on ties (earlier = simpler).
        if max_hits < hits:
            max_hits = hits
            best_tags = cell_tags

    # Require more than 2 hits to reduce accidental numeric coincidences.
    if max_hits > 2:
        return best_tags
    return None
178+
179+
180+
def match_many(output_dir, metric_name, tables, values):
    """Tag every table of every paper against the known sota ``values``.

    For each table that yields a confident tag layout, writes a
    ``celltags``-named CSV next to the paper's output directory
    (``output_dir / arxiv_id``), creating directories as needed.
    """
    for arxiv_id, paper_tables in tables.items():
        for table_name, tab in paper_tables.items():
            tagged = mark_with_best_comparator(metric_name, arxiv_id, tab, values)
            if tagged is None:
                continue
            paper_dir = output_dir / arxiv_id
            paper_dir.mkdir(parents=True, exist_ok=True)
            # Mirror the source naming: "table_01.csv" -> "celltags_01.csv".
            tagged.to_csv(paper_dir / table_name.replace("table", "celltags"),
                          header=None, index=None)
188+
189+
190+
def normalize_metric(value):
    """Parse a reported metric value into a Decimal.

    The raw value is cleaned by the sibling normalize_float_value(); values
    recognized as not-available markers (``metric_na``) become Decimal NaN.
    """
    cleaned = normalize_float_value(str(value))
    return Decimal("NaN") if cleaned in metric_na else Decimal(cleaned)
195+
196+
197+
def normalize_cell(cell):
    """Extract every float-looking token from a raw cell string.

    Uses the module-level ``float_value_re`` to find candidates, strips
    internal whitespace with ``whitespace_re``, and returns them as a list
    of Decimals (possibly empty).
    """
    return [
        Decimal(whitespace_re.sub("", found[0]))
        for found in float_value_re.findall(cell)
    ]
202+
203+
204+
def normalize_table(table):
    """Replace each raw cell string with the list of Decimal values it contains."""
    return table.applymap(normalize_cell)
206+
140207

141208
# for each task with sota row
142209
# arxivs <- list of papers related to the task
@@ -151,18 +218,34 @@ def match_metric(metric, tables, value):
151218
# if table.arxiv_id == paper_id: mark with this-tag
152219

153220

154-
def label_tables(tasksfile, tables_dir, output):
221+
def label_tables(tasksfile, tables_dir, output, output_dir):
222+
output_dir = Path(output_dir)
155223
tasks = get_sota_tasks(tasksfile)
156224
metadata, tables = get_tables(tables_dir)
157225

158-
# for arxiv_id in tables:
159-
# for t in tables[arxiv_id]:
160-
# table = tables[arxiv_id][t]
161-
# for col in table:
162-
# for row in table[col]:
163-
# print(row)
164-
# return
226+
arxivs_by_metrics = {}
227+
228+
tables = {arxiv_id: {tab: normalize_table(tables[arxiv_id][tab]) for tab in tables[arxiv_id]} for arxiv_id in tables}
229+
230+
for task in tasks:
231+
for dataset in task.datasets:
232+
for row in dataset.sota.rows:
233+
match = arxiv_url_re.match(row.paper_url)
234+
if match is not None:
235+
arxiv_id = match.group("arxiv_id")
236+
for metric in row.metrics:
237+
arxivs_by_metrics.setdefault((task.name, dataset.name, metric), []).append(
238+
dict(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
239+
normalized=normalize_metric(row.metrics[metric])
240+
)
241+
)
242+
243+
for task, dataset, metric in arxivs_by_metrics:
244+
records = arxivs_by_metrics[(task, dataset, metric)]
245+
tabs = {r["arxiv_id"]: tables[r["arxiv_id"]] for r in records if r["arxiv_id"] in tables}
246+
match_many(output_dir, metric, tabs, records)
165247

248+
return
166249
tables_with_sota = []
167250
for task in tasks:
168251
for dataset in task.datasets:

tabular.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pandas as pd
2+
import numpy as np
3+
import json
4+
5+
6+
class Tabular:
7+
def __init__(self, data, caption):
8+
self.data = data
9+
self.cell_tags = pd.DataFrame().reindex_like(data).fillna('')
10+
self.datasets = set()
11+
self.metrics = set()
12+
self.caption = caption
13+
14+
def mark_with_metric(self, metric_name):
15+
self.metrics.add(metric_name)
16+
17+
def mark_with_dataset(self, dataset_name):
18+
self.datasets.add(dataset_name)
19+
20+

0 commit comments

Comments
 (0)