Skip to content

Commit 4fef880

Browse files
committed
Extract numbers for fuzzy search
1 parent 969717c commit 4fef880

File tree

1 file changed

+158
-4
lines changed

1 file changed

+158
-4
lines changed

label_tables.py

Lines changed: 158 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,175 @@
22

33
import fire
44
from sota_extractor.taskdb import TaskDB
5+
from pathlib import Path
6+
import json
7+
import re
8+
import pandas as pd
9+
import sys
10+
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
11+
12+
13+
# Matches arxiv abs/pdf/e-print URLs and captures the numeric paper id, e.g.
# https://arxiv.org/abs/1810.04805 -> "1810.04805".
# Fix: dots are escaped — previously "www." / "arxiv.org" matched ANY character
# in the dot position, so hosts like "wwwxarxiv.org" slipped through.
arxiv_url_re = re.compile(r"^https?://(www\.)?arxiv\.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
514

615
def get_sota_tasks(filename):
    """Load the task database from *filename* and return only the tasks that
    have SOTA results attached."""
    task_db = TaskDB()
    task_db.load_tasks(filename)
    return task_db.tasks_with_sota()

1120

12-
def label_tables(tasksfile):
21+
def get_metadata(filename):
    """Parse and return the JSON document stored at *filename*."""
    return json.loads(Path(filename).read_text())
25+
26+
def get_table(filename):
    """Read an extracted table CSV (no header row) with every cell as ``str``;
    missing cells become ``''``.

    Fix: a file with no data previously returned ``[]`` (a list) while the
    normal path returned a DataFrame — an inconsistent return type.  Return an
    empty DataFrame instead so callers always iterate the same shape.
    """
    try:
        return pd.read_csv(filename, header=None, dtype=str).fillna('')
    except pd.errors.EmptyDataError:
        return pd.DataFrame()
31+
32+
33+
def get_tables(tables_dir):
    """Scan *tables_dir* for per-paper subdirectories (named by arxiv id),
    each holding a ``metadata.json`` plus the CSV tables it describes.

    Returns a pair ``(all_metadata, all_tables)``, both dicts keyed by
    arxiv id; ``all_tables[id]`` maps a table filename to its parsed table.
    """
    all_metadata = {}
    all_tables = {}
    for meta_path in Path(tables_dir).glob("*/metadata.json"):
        paper_dir = meta_path.parent
        arxiv_id = paper_dir.name
        metadata = get_metadata(meta_path)
        all_metadata[arxiv_id] = metadata
        all_tables[arxiv_id] = {
            entry['filename']: get_table(paper_dir / entry['filename'])
            for entry in metadata
        }
    return all_metadata, all_tables
44+
45+
46+
# Metric-value strings that mean "no result reported".
metric_na = ['-','']


# problematic values of metrics found in evaluation-tables.json
# F0.5, 70.14 (measured by Ge et al., 2018)
# Test Time, 0.33s/img
# Accuracy, 77,62%
# Electronics, 85,06
# BLEU-1, 54.60/55.55
# BLEU-4, 26.71/27.78
# MRPC, 78.6/84.4
# MRPC, 76.2/83.1
# STS, 78.9/78.6
# STS, 75.8/75.5
# BLEU score,41.0*
# BLEU score,28.5*
# SemEval 2007,**55.6**
# Senseval 2,**69.0**
# Senseval 3,**66.9**
# MAE, 2.42±0.01

## multiple times
# Number of params, 0.8B
# Number of params, 88M
# Parameters, 580k
# Parameters, 3.1m
# Params, 22M
73+
74+
75+
76+
# First float-looking token (optionally signed, optional exponent) in a string.
float_value_re = re.compile(r"([+-]?\s*(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)")
whitespace_re = re.compile(r"\s+")


def normalize_float_value(s):
    """Extract the first numeric token from *s* with internal whitespace
    removed, or '-' (the NA marker) when *s* contains no number."""
    match = float_value_re.search(s)
    if match is None:
        return '-'
    return whitespace_re.sub("", match.group(0))
85+
86+
87+
def test_near(x, precise):
88+
for rounding in [ROUND_DOWN, ROUND_HALF_UP]:
89+
try:
90+
if x == precise.quantize(x, rounding=rounding):
91+
return True
92+
except InvalidOperation:
93+
pass
94+
return False
95+
96+
97+
def fuzzy_match(metric, metric_value, target_value):
    """Return True when *metric_value* matches any number embedded in the cell
    text *target_value*, allowing for truncation/rounding and a factor-of-100
    difference (percentage vs. fraction, via Decimal.shift(2)).

    NOTE(review): *metric* is currently unused; kept for interface stability.
    """
    normalized = normalize_float_value(str(metric_value))
    if normalized in metric_na:
        return False
    wanted = Decimal(normalized)

    for groups in float_value_re.findall(target_value):
        candidate = Decimal(whitespace_re.sub("", groups[0]))
        if (test_near(wanted, candidate)
                or test_near(wanted.shift(2), candidate)
                or test_near(wanted, candidate.shift(2))):
            return True
    return False
121+
122+
123+
def match_metric(metric, tables, value):
    """Return the names of the tables in *tables* (a dict of name -> table)
    containing at least one cell that fuzzily matches *value*."""
    matching_tables = []
    for name, table in tables.items():
        found = any(
            fuzzy_match(metric, value, cell)
            for col in table
            for cell in table[col]
        )
        if found:
            matching_tables.append(name)
    return matching_tables
136+
137+
138+
def label_tables(tasksfile, tables_dir):
    """For every SOTA result whose paper URL is a recognizable arxiv link,
    count how many of that paper's extracted tables contain the reported
    metric value, and print one ``metric,count`` line per metric to stdout.

    Papers with no extracted tables are reported on stderr and skipped.
    (Cleanup: removed the commented-out debug prints and flattened the
    nesting with guard clauses; output is unchanged.)
    """
    tasks = get_sota_tasks(tasksfile)
    metadata, tables = get_tables(tables_dir)

    for task in tasks:
        for dataset in task.datasets:
            for row in dataset.sota.rows:
                # TODO: some results have more than one url, CoRR + journal / conference
                # check if we have the same results for both
                match = arxiv_url_re.match(row.paper_url)
                if match is None:
                    continue
                arxiv_id = match.group("arxiv_id")
                if arxiv_id not in tables:
                    print(f"No tables for {arxiv_id}. Skipping", file=sys.stderr)
                    continue

                for metric in row.metrics:
                    matching = match_metric(metric, tables[arxiv_id], row.metrics[metric])
                    print(f"{metric},{len(matching)}")
172+
173+
174+
21175

22176
# CLI entry point: `python label_tables.py TASKSFILE TABLES_DIR` via python-fire.
if __name__ == "__main__": fire.Fire(label_tables)

0 commit comments

Comments
 (0)