import pandas as pd
import sys
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
- from collections import Counter
+ from collections import Counter, namedtuple


arxiv_url_re = re.compile(r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
@@ -142,50 +142,67 @@ def match_metric(metric, tables, value):
    return matching_tables


- comparators = [
-     test_near,
-     lambda metric, target: test_near(metric.shift(2), target),
-     lambda metric, target: test_near(metric, target.shift(2)),
-     lambda metric, target: test_near(Decimal("1") - metric, target),
-     lambda metric, target: test_near(Decimal("100") - metric.shift(2), target),
-     lambda metric, target: test_near(Decimal("100") - metric, target.shift(2))
- ]
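+ # Each comparator tests an extracted metric value (a) against a table cell (b) under a
+ # different rescaling or complement: Decimal.shift(2) effectively multiplies by 100, so
+ # "100a=b" matches a fraction against a percentage, while the "1-a"/"100-a" forms match
+ # complementary metrics such as error rate vs. accuracy.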
+ comparators = {
+     "a=b": test_near,
+     "100a=b": lambda metric, target: test_near(metric.shift(2), target),
+     "a=100b": lambda metric, target: test_near(metric, target.shift(2)),
+     "1-a=b": lambda metric, target: test_near(Decimal("1") - metric, target),
+     "100-a=b": lambda metric, target: test_near(Decimal("100") - metric, target),
+     "100-100a=b": lambda metric, target: test_near(Decimal("100") - metric.shift(2), target),
+     "100-a=100b": lambda metric, target: test_near(Decimal("100") - metric, target.shift(2))
+ }


def empty_celltags_like(table):
-     return = pd.DataFrame().reindex_like(table).fillna('')
+     return pd.DataFrame().reindex_like(table).fillna('')
+
+
+ def mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name):
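+     # Compare every cell of the table against every extracted record using the named
+     # comparator; matched cells are annotated with <hit>...</hit> tags and the total
+     # number of hits is returned together with the tag frame.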
+     comparator = comparators[comp_name]
+     rows, cols = table.shape
+     hits = 0
+     cell_tags = empty_celltags_like(table)
+     for col in range(cols):
+         for row in range(rows):
+             for val in table.iloc[row, col]:
+                 for record in values:
+                     if comparator(record.normalized, val):
+                         hits += 1
+                         tags = f"<hit><sota>{record.value}</sota>" + \
+                                f"<paper>{record.arxiv_id}</paper>" + \
+                                f"<model>{record.model}</model>" + \
+                                f"<metric>{metric_name}</metric>" + \
+                                f"<dataset>{dataset_name}</dataset>" + \
+                                f"<task>{task_name}</task>"
+                         if arxiv_id == record.arxiv_id:
+                             tags += "<this_paper/>"
+                         tags += f"<comparator>{comp_name}</comparator>" + \
+                                 f"<matched_cell>{val}</matched_cell></hit>"
+                         cell_tags.iloc[row, col] += tags
+     return cell_tags, hits


def mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values):
    max_hits = 0
    best_tags = None
-     rows, cols = table.shape

-     for comparator in comparators:
-         hits = 0
-         cell_tags = empty_celltags_like(table)
-         for col in range(cols):
-             for row in range(rows):
-                 for val in table.iloc[row, col]:
-                     for record in values:
-                         if comparator(record["normalized"], val):
-                             hits += 1
-                             tags = f"<sota>{record['value']}</sota>" + \
-                                    f"<paper>{record['arxiv_id']}</paper>" + \
-                                    f"<model>{record['model']}</model>" + \
-                                    f"<metric>{metric_name}</metric>" + \
-                                    f"<dataset>{dataset_name}</dataset>" + \
-                                    f"<task>{task_name}</task>"
-                             if arxiv_id == record["arxiv_id"]:
-                                 tags += "<this_paper>"
-                             cell_tags.iloc[row, col] += tags
+     for comp_name in comparators:
+         cell_tags, hits = mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name)
        if max_hits < hits:
            max_hits = hits
            best_tags = cell_tags

    return best_tags


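+ # Unlike mark_with_best_comparator, which keeps only the tags of the comparator with the
+ # most hits, this accumulates the tags produced by every comparator.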
+ def mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, table, values):
+     all_tags = empty_celltags_like(table)
+     for comp_name in comparators:
+         cell_tags, _ = mark_with_comparator(task_name, dataset_name, metric_name, arxiv_id, table, values, comp_name)
+         all_tags += cell_tags
+
+     return all_tags
+
def normalize_string(s):
    return s.lower().strip()

@@ -211,14 +228,13 @@ def mark_strings(table, tags, values):
def match_many(output_dir, task_name, dataset_name, metric_name, tables, values):
    for arxiv_id in tables:
        for table in tables[arxiv_id]:
-             best = mark_with_best_comparator(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
+             tags = mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
            global metatables
-             if best is not None:
-                 key = (arxiv_id, table)
-                 if key in metatables:
-                     metatables[key] += best
-                 else:
-                     metatables[key] = best
+             key = (arxiv_id, table)
+             if key in metatables:
+                 metatables[key] += tags
+             else:
+                 metatables[key] = tags


def normalize_metric(value):
@@ -252,6 +268,7 @@ def normalize_table(table):
# mark table with a given dataset_name and metric_name
# mark hit cells with sota-tag, model_name and paper_id
# if table.arxiv_id == paper_id: mark with this-tag
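+ # Unlike a plain dict, a namedtuple is hashable, so extracted results can be collected
+ # in a set (see the set()/add() change in label_tables below).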
+ PaperResult = namedtuple("PaperResult", ["arxiv_id", "model", "value", "normalized"])


def label_tables(tasksfile, tables_dir, output, output_dir):
@@ -270,15 +287,15 @@ def label_tables(tasksfile, tables_dir, output, output_dir):
                if match is not None:
                    arxiv_id = match.group("arxiv_id")
                    for metric in row.metrics:
-                         arxivs_by_metrics.setdefault((task.name, dataset.name, metric), []).append(
-                             dict(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
+                         arxivs_by_metrics.setdefault((task.name, dataset.name, metric), set()).add(
+                             PaperResult(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
                                 normalized=normalize_metric(row.metrics[metric])
                            )
                        )

    for task, dataset, metric in arxivs_by_metrics:
        records = arxivs_by_metrics[(task, dataset, metric)]
-         tabs = {r["arxiv_id"]: tables[r["arxiv_id"]] for r in records if r["arxiv_id"] in tables}
+         tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
        match_many(output_dir, task, dataset, metric, tabs, records)

    global metatables