Skip to content

Commit 0286cca

Browse files
committed
Save cell tags to files
1 parent 43ea412 commit 0286cca

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

label_tables.py

Lines changed: 26 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -142,13 +142,13 @@ def match_metric(metric, tables, value):
142142
test_near,
143143
lambda metric, target: test_near(metric.shift(2), target),
144144
lambda metric, target: test_near(metric, target.shift(2)),
145-
lambda metric, target: test_near(metric, Decimal("1") - target),
146-
lambda metric, target: test_near(metric.shift(2), Decimal("100") - target),
147-
lambda metric, target: test_near(metric, (Decimal("1") - target).shift(2))
145+
lambda metric, target: test_near(Decimal("1") - metric, target),
146+
lambda metric, target: test_near(Decimal("100") - metric.shift(2), target),
147+
lambda metric, target: test_near(Decimal("100") - metric, target.shift(2))
148148
]
149149

150150

151-
def mark_with_best_comparator(metric_name, arxiv_id, table, values):
151+
def mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values):
152152
max_hits = 0
153153
best_tags = None
154154
rows, cols = table.shape
@@ -164,27 +164,32 @@ def mark_with_best_comparator(metric_name, arxiv_id, table, values):
164164
hits += 1
165165
tags = f"<sota>{record['value']}</sota>" +\
166166
f"<paper>{record['arxiv_id']}</paper>" +\
167-
f"<model>{record['model']}</model>"
167+
f"<model>{record['model']}</model>" +\
168+
f"<metric>{metric_name}</metric>" +\
169+
f"<dataset>{dataset_name}</dataset>" +\
170+
f"<task>{task_name}</task>"
168171
if arxiv_id == record["arxiv_id"]:
169172
tags += "<this_paper>"
170173
cell_tags.iloc[row, col] += tags
171174
if max_hits < hits:
172175
max_hits = hits
173176
best_tags = cell_tags
174177

175-
if max_hits > 2:
176-
return best_tags
177-
return None
178+
return best_tags
178179

179180

180-
def match_many(output_dir, metric_name, tables, values):
181+
metatables = {}
182+
def match_many(output_dir, task_name, dataset_name, metric_name, tables, values):
181183
for arxiv_id in tables:
182184
for table in tables[arxiv_id]:
183-
best = mark_with_best_comparator(metric_name, arxiv_id, tables[arxiv_id][table], values)
185+
best = mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
184186
if best is not None:
185-
out = output_dir / arxiv_id
186-
out.mkdir(parents=True, exist_ok=True)
187-
best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
187+
global metatables
188+
key = (arxiv_id, table)
189+
if key in metatables:
190+
metatables[key] += best
191+
else:
192+
metatables[key] = best
188193

189194

190195
def normalize_metric(value):
@@ -243,7 +248,14 @@ def label_tables(tasksfile, tables_dir, output, output_dir):
243248
for task, dataset, metric in arxivs_by_metrics:
244249
records = arxivs_by_metrics[(task, dataset, metric)]
245250
tabs = {r["arxiv_id"]: tables[r["arxiv_id"]] for r in records if r["arxiv_id"] in tables}
246-
match_many(output_dir, metric, tabs, records)
251+
match_many(output_dir, task, dataset, metric, tabs, records)
252+
253+
global metatables
254+
255+
for (arxiv_id, table), best in metatables.items():
256+
out = output_dir / arxiv_id
257+
out.mkdir(parents=True, exist_ok=True)
258+
best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
247259

248260
return
249261
tables_with_sota = []

0 commit comments

Comments
 (0)