8
8
import pandas as pd
9
9
import sys
10
10
from decimal import Decimal , ROUND_DOWN , ROUND_HALF_UP , InvalidOperation
11
+ from collections import Counter
11
12
12
13
13
14
arxiv_url_re = re .compile (r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$" )
@@ -29,7 +30,7 @@ def get_table(filename):
29
30
try :
30
31
return pd .read_csv (filename , header = None , dtype = str ).fillna ('' )
31
32
except pd .errors .EmptyDataError :
32
- return []
33
+ return pd . DataFrame ()
33
34
34
35
35
36
def get_tables (tables_dir ):
@@ -137,6 +138,72 @@ def match_metric(metric, tables, value):
137
138
return matching_tables
138
139
139
140
141
def _cmp_direct(metric, target):
    # Plain comparison: normalized metric against the cell value as-is.
    return test_near(metric, target)


def _cmp_metric_as_percent(metric, target):
    # Metric stored as a fraction, cell reported as a percentage.
    return test_near(metric.shift(2), target)


def _cmp_target_as_percent(metric, target):
    # Cell stored as a fraction, metric reported as a percentage.
    return test_near(metric, target.shift(2))


def _cmp_complement(metric, target):
    # Complementary quantity (e.g. accuracy vs. error rate), both fractions.
    return test_near(metric, Decimal("1") - target)


def _cmp_percent_complement(metric, target):
    # Metric as fraction, cell as complementary percentage.
    return test_near(metric.shift(2), Decimal("100") - target)


def _cmp_complement_as_percent(metric, target):
    # Complement of the cell value rescaled to a percentage.
    return test_near(metric, (Decimal("1") - target).shift(2))


# Candidate value transformations tried when matching a metric against table
# cells.  Order matters: ties in match counts are won by earlier comparators.
comparators = [
    _cmp_direct,
    _cmp_metric_as_percent,
    _cmp_target_as_percent,
    _cmp_complement,
    _cmp_percent_complement,
    _cmp_complement_as_percent,
]
149
+
150
+
151
def mark_with_best_comparator(metric_name, arxiv_id, table, values):
    """Tag cells of *table* that match known SOTA metric values.

    Each comparator from the module-level ``comparators`` list is tried in
    turn; the one producing the most cell/record matches wins.  Matching
    cells are annotated with ``<sota>``/``<paper>``/``<model>`` tags (plus
    ``<this_paper>`` when the record comes from the same arXiv paper).

    Returns a DataFrame of tag strings shaped like *table*, or ``None`` when
    the best comparator scored 2 or fewer hits (too weak to trust).
    NOTE(review): *metric_name* is currently unused — kept for interface
    compatibility with callers.
    """
    n_rows, n_cols = table.shape
    best_score = 0
    best = None

    for compare in comparators:
        score = 0
        # Empty-string grid with the same index/columns as the input table.
        marked = pd.DataFrame().reindex_like(table).fillna('')
        for r in range(n_rows):
            for c in range(n_cols):
                # Each cell holds an iterable of normalized Decimal values.
                for cell_value in table.iloc[r, c]:
                    for rec in values:
                        if not compare(rec["normalized"], cell_value):
                            continue
                        score += 1
                        tag = (f"<sota>{rec['value']}</sota>"
                               f"<paper>{rec['arxiv_id']}</paper>"
                               f"<model>{rec['model']}</model>")
                        if arxiv_id == rec["arxiv_id"]:
                            tag += "<this_paper>"
                        marked.iloc[r, c] += tag
        if score > best_score:
            best_score = score
            best = marked

    # Require more than two hits before trusting the annotation.
    if best_score > 2:
        return best
    return None
178
+
179
+
180
def match_many(output_dir, metric_name, tables, values):
    """Annotate every table of every paper and save non-empty annotations.

    *tables* maps ``arxiv_id -> {table_filename -> normalized DataFrame}``.
    For each table whose best comparator scored enough hits, the cell-tag
    DataFrame is written to ``output_dir/<arxiv_id>/`` with ``table``
    replaced by ``celltags`` in the filename.
    """
    for arxiv_id, paper_tables in tables.items():
        for table_name, table in paper_tables.items():
            tagged = mark_with_best_comparator(metric_name, arxiv_id, table, values)
            if tagged is None:
                continue
            paper_dir = output_dir / arxiv_id
            paper_dir.mkdir(parents=True, exist_ok=True)
            tagged.to_csv(paper_dir / table_name.replace("table", "celltags"),
                          header=None, index=None)
188
+
189
+
190
def normalize_metric(value):
    """Convert a raw metric value to a ``Decimal``.

    The value is first cleaned by ``normalize_float_value``; strings listed
    in ``metric_na`` (not-available markers) become ``Decimal("NaN")``.
    """
    cleaned = normalize_float_value(str(value))
    return Decimal("NaN") if cleaned in metric_na else Decimal(cleaned)
195
+
196
+
197
def normalize_cell(cell):
    """Extract all numeric values from a cell string as ``Decimal``s.

    Uses the module-level ``float_value_re`` to find candidate numbers and
    ``whitespace_re`` to strip internal whitespace before conversion.
    """
    return [
        Decimal(whitespace_re.sub("", match[0]))
        for match in float_value_re.findall(cell)
    ]
202
+
203
+
204
def normalize_table(table):
    """Apply ``normalize_cell`` to every cell, column by column."""
    return table.apply(lambda column: column.map(normalize_cell))
206
+
140
207
141
208
# for each task with sota row
142
209
# arxivs <- list of papers related to the task
@@ -151,18 +218,34 @@ def match_metric(metric, tables, value):
151
218
# if table.arxiv_id == paper_id: mark with this-tag
152
219
153
220
154
- def label_tables (tasksfile , tables_dir , output ):
221
+ def label_tables (tasksfile , tables_dir , output , output_dir ):
222
+ output_dir = Path (output_dir )
155
223
tasks = get_sota_tasks (tasksfile )
156
224
metadata , tables = get_tables (tables_dir )
157
225
158
- # for arxiv_id in tables:
159
- # for t in tables[arxiv_id]:
160
- # table = tables[arxiv_id][t]
161
- # for col in table:
162
- # for row in table[col]:
163
- # print(row)
164
- # return
226
+ arxivs_by_metrics = {}
227
+
228
+ tables = {arxiv_id : {tab : normalize_table (tables [arxiv_id ][tab ]) for tab in tables [arxiv_id ]} for arxiv_id in tables }
229
+
230
+ for task in tasks :
231
+ for dataset in task .datasets :
232
+ for row in dataset .sota .rows :
233
+ match = arxiv_url_re .match (row .paper_url )
234
+ if match is not None :
235
+ arxiv_id = match .group ("arxiv_id" )
236
+ for metric in row .metrics :
237
+ arxivs_by_metrics .setdefault ((task .name , dataset .name , metric ), []).append (
238
+ dict (arxiv_id = arxiv_id , model = row .model_name , value = row .metrics [metric ],
239
+ normalized = normalize_metric (row .metrics [metric ])
240
+ )
241
+ )
242
+
243
+ for task , dataset , metric in arxivs_by_metrics :
244
+ records = arxivs_by_metrics [(task , dataset , metric )]
245
+ tabs = {r ["arxiv_id" ]: tables [r ["arxiv_id" ]] for r in records if r ["arxiv_id" ] in tables }
246
+ match_many (output_dir , metric , tabs , records )
165
247
248
+ return
166
249
tables_with_sota = []
167
250
for task in tasks :
168
251
for dataset in task .datasets :
0 commit comments