@@ -142,13 +142,13 @@ def match_metric(metric, tables, value):
142
142
test_near ,
143
143
lambda metric , target : test_near (metric .shift (2 ), target ),
144
144
lambda metric , target : test_near (metric , target .shift (2 )),
145
- lambda metric , target : test_near (metric , Decimal ("1" ) - target ),
146
- lambda metric , target : test_near (metric . shift ( 2 ), Decimal ("100" ) - target ),
147
- lambda metric , target : test_near (metric , ( Decimal ("1 " ) - target ) .shift (2 ))
145
+ lambda metric , target : test_near (Decimal ("1" ) - metric , target ),
146
+ lambda metric , target : test_near (Decimal ("100" ) - metric . shift ( 2 ), target ),
147
+ lambda metric , target : test_near (Decimal ("100 " ) - metric , target .shift (2 ))
148
148
]
149
149
150
150
151
- def mark_with_best_comparator (metric_name , arxiv_id , table , values ):
151
+ def mark_with_best_comparator (task_name , dataset_name , metric_name , arxiv_id , table , values ):
152
152
max_hits = 0
153
153
best_tags = None
154
154
rows , cols = table .shape
@@ -164,27 +164,32 @@ def mark_with_best_comparator(metric_name, arxiv_id, table, values):
164
164
hits += 1
165
165
tags = f"<sota>{ record ['value' ]} </sota>" + \
166
166
f"<paper>{ record ['arxiv_id' ]} </paper>" + \
167
- f"<model>{ record ['model' ]} </model>"
167
+ f"<model>{ record ['model' ]} </model>" + \
168
+ f"<metric>{ metric_name } </metric>" + \
169
+ f"<dataset>{ dataset_name } </dataset>" + \
170
+ f"<task>{ task_name } </task>"
168
171
if arxiv_id == record ["arxiv_id" ]:
169
172
tags += "<this_paper>"
170
173
cell_tags .iloc [row , col ] += tags
171
174
if max_hits < hits :
172
175
max_hits = hits
173
176
best_tags = cell_tags
174
177
175
- if max_hits > 2 :
176
- return best_tags
177
- return None
178
+ return best_tags
178
179
179
180
180
- def match_many (output_dir , metric_name , tables , values ):
181
+ metatables = {}
182
+ def match_many (output_dir , task_name , dataset_name , metric_name , tables , values ):
181
183
for arxiv_id in tables :
182
184
for table in tables [arxiv_id ]:
183
- best = mark_with_best_comparator (metric_name , arxiv_id , tables [arxiv_id ][table ], values )
185
+ best = mark_with_best_comparator (task_name , dataset_name , metric_name , arxiv_id , tables [arxiv_id ][table ], values )
184
186
if best is not None :
185
- out = output_dir / arxiv_id
186
- out .mkdir (parents = True , exist_ok = True )
187
- best .to_csv (out / table .replace ("table" , "celltags" ), header = None , index = None )
187
+ global metatables
188
+ key = (arxiv_id , table )
189
+ if key in metatables :
190
+ metatables [key ] += best
191
+ else :
192
+ metatables [key ] = best
188
193
189
194
190
195
def normalize_metric (value ):
@@ -243,7 +248,14 @@ def label_tables(tasksfile, tables_dir, output, output_dir):
243
248
for task , dataset , metric in arxivs_by_metrics :
244
249
records = arxivs_by_metrics [(task , dataset , metric )]
245
250
tabs = {r ["arxiv_id" ]: tables [r ["arxiv_id" ]] for r in records if r ["arxiv_id" ] in tables }
246
- match_many (output_dir , metric , tabs , records )
251
+ match_many (output_dir , task , dataset , metric , tabs , records )
252
+
253
+ global metatables
254
+
255
+ for (arxiv_id , table ), best in metatables .items ():
256
+ out = output_dir / arxiv_id
257
+ out .mkdir (parents = True , exist_ok = True )
258
+ best .to_csv (out / table .replace ("table" , "celltags" ), header = None , index = None )
247
259
248
260
return
249
261
tables_with_sota = []
0 commit comments