Skip to content

Commit 8e40f84

Browse files
committed
List all sota tables
1 parent c155467 commit 8e40f84

File tree

1 file changed

+38
-5
lines changed

1 file changed

+38
-5
lines changed

label_tables.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def get_sota_tasks(filename):
2020

2121
def get_metadata(filename):
2222
with open(filename, "r") as f:
23-
return json.load(f)
23+
j = json.load(f)
24+
metadata = {x["filename"]:x["caption"] for x in j}
25+
return metadata
2426

2527

2628
def get_table(filename):
@@ -39,7 +41,7 @@ def get_tables(tables_dir):
3941
basedir = metadata_filename.parent
4042
arxiv_id = basedir.name
4143
all_metadata[arxiv_id] = metadata
42-
all_tables[arxiv_id] = {m['filename']:get_table(basedir / m['filename']) for m in metadata}
44+
all_tables[arxiv_id] = {t:get_table(basedir / t) for t in metadata}
4345
return all_metadata, all_tables
4446

4547

@@ -135,7 +137,21 @@ def match_metric(metric, tables, value):
135137
return matching_tables
136138

137139

138-
def label_tables(tasksfile, tables_dir):
140+
141+
# for each task with sota row
142+
# arxivs <- list of papers related to the task
143+
# for each (dataset_name, metric_name) of the task:
144+
# for each table in arxivs
145+
# for each fuzzy_comparator
146+
# count number of task's sota rows found in the table using comparator
147+
# comparator <- comparator with the largest number of hits
148+
# if hits > hits_threshold:
149+
# mark table with a given dataset_name and metric_name
150+
# mark hit cells with sota-tag, model_name and paper_id
151+
# if table.arxiv_id == paper_id: mark with this-tag
152+
153+
154+
def label_tables(tasksfile, tables_dir, output):
139155
tasks = get_sota_tasks(tasksfile)
140156
metadata, tables = get_tables(tables_dir)
141157

@@ -146,6 +162,8 @@ def label_tables(tasksfile, tables_dir):
146162
# for row in table[col]:
147163
# print(row)
148164
# return
165+
166+
tables_with_sota = []
149167
for task in tasks:
150168
for dataset in task.datasets:
151169
for row in dataset.sota.rows:
@@ -163,13 +181,28 @@ def label_tables(tasksfile, tables_dir):
163181
#print(f"{metric}\t{row.metrics[metric]}")
164182
#print((task.name, dataset.name, metric, row.model_name, row.metrics[metric], row.paper_url))
165183
matching = match_metric(metric, tables[arxiv_id], row.metrics[metric])
184+
if len(matching) == 1:
185+
sota_table = matching[0]
186+
187+
tables_with_sota.append(
188+
dict(
189+
task_name=task.name,
190+
dataset_name=dataset.name,
191+
metric_name=metric,
192+
model_name=row.model_name,
193+
metric_value=row.metrics[metric],
194+
paper_url=row.paper_url,
195+
table_caption=metadata[arxiv_id][sota_table],
196+
table_filename=f"{arxiv_id}/{sota_table}"
197+
)
198+
)
166199
#if not matching:
167200
# print(f"{metric}, {row.metrics[metric]}, {arxiv_id}")
168-
print(f"{metric},{len(matching)}")
201+
#print(f"{metric},{len(matching)}")
169202
#if matching:
170203
# print((task.name, dataset.name, metric, row.model_name, row.metrics[metric], row.paper_url))
171204
# print(matching)
172-
205+
pd.DataFrame(tables_with_sota).to_csv(output, index=None)
173206

174207

175208

0 commit comments

Comments
 (0)