2 | 2 |
3 | 3 | import fire
4 | 4 | from sota_extractor.taskdb import TaskDB
| 5 | +from pathlib import Path |
| 6 | +import json |
| 7 | +import re |
| 8 | +import pandas as pd |
| 9 | +import sys |
| 10 | +from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation |
| 11 | + |
| 12 | + |
| 13 | +arxiv_url_re = re.compile(r"^https?://(www\.)?arxiv\.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$") |
5 | 14 |
6 | 15 | def get_sota_tasks(filename):
7 | 16 |     db = TaskDB()
8 | 17 |     db.load_tasks(filename)
9 | 18 |     return db.tasks_with_sota()
10 | 19 |
11 | 20 |
12 | | -def label_tables(tasksfile): |
| 21 | +def get_metadata(filename): |
| 22 | + with open(filename, "r") as f: |
| 23 | + return json.load(f) |
| 24 | + |
| 25 | + |
| 26 | +def get_table(filename): |
| 27 | + try: |
| 28 | + return pd.read_csv(filename, header=None, dtype=str).fillna('') |
| 29 | + except pd.errors.EmptyDataError: |
| 30 | +        return pd.DataFrame()  # keep the return type consistent when a CSV is empty |
| 31 | + |
| 32 | + |
| 33 | +def get_tables(tables_dir): |
| 34 | + tables_dir = Path(tables_dir) |
| 35 | + all_metadata = {} |
| 36 | + all_tables = {} |
| 37 | + for metadata_filename in tables_dir.glob("*/metadata.json"): |
| 38 | + metadata = get_metadata(metadata_filename) |
| 39 | + basedir = metadata_filename.parent |
| 40 | + arxiv_id = basedir.name |
| 41 | + all_metadata[arxiv_id] = metadata |
| 42 | +        all_tables[arxiv_id] = {m['filename']: get_table(basedir / m['filename']) for m in metadata} |
| 43 | + return all_metadata, all_tables |
| 44 | + |
| 45 | + |
| 46 | +metric_na = ['-',''] |
| 47 | + |
| 48 | + |
| 49 | +# problematic values of metrics found in evaluation-tables.json |
| 50 | +# F0.5, 70.14 (measured by Ge et al., 2018) |
| 51 | +# Test Time, 0.33s/img |
| 52 | +# Accuracy, 77,62% |
| 53 | +# Electronics, 85,06 |
| 54 | +# BLEU-1, 54.60/55.55 |
| 55 | +# BLEU-4, 26.71/27.78 |
| 56 | +# MRPC, 78.6/84.4 |
| 57 | +# MRPC, 76.2/83.1 |
| 58 | +# STS, 78.9/78.6 |
| 59 | +# STS, 75.8/75.5 |
| 60 | +# BLEU score,41.0* |
| 61 | +# BLEU score,28.5* |
| 62 | +# SemEval 2007,**55.6** |
| 63 | +# Senseval 2,**69.0** |
| 64 | +# Senseval 3,**66.9** |
| 65 | +# MAE, 2.42±0.01 |
| 66 | + |
| 67 | +## multiple times |
| 68 | +# Number of params, 0.8B |
| 69 | +# Number of params, 88M |
| 70 | +# Parameters, 580k |
| 71 | +# Parameters, 3.1m |
| 72 | +# Params, 22M |
| 73 | + |
| 74 | + |
| 75 | + |
| 76 | +float_value_re = re.compile(r"([+-]?\s*(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)") |
| 77 | +whitespace_re = re.compile(r"\s+") |
| 78 | + |
| 79 | + |
| 80 | +def normalize_float_value(s): |
| 81 | + match = float_value_re.search(s) |
| 82 | + if match: |
| 83 | + return whitespace_re.sub("", match.group(0)) |
| 84 | + return '-' |
| 85 | + |
| 86 | + |
| 87 | +def test_near(x, precise): |
| 88 | + for rounding in [ROUND_DOWN, ROUND_HALF_UP]: |
| 89 | + try: |
| 90 | + if x == precise.quantize(x, rounding=rounding): |
| 91 | + return True |
| 92 | + except InvalidOperation: |
| 93 | + pass |
| 94 | + return False |
| 95 | + |
| 96 | + |
| 97 | +def fuzzy_match(metric, metric_value, target_value): |
| 98 | + metric_value = normalize_float_value(str(metric_value)) |
| 99 | + if metric_value in metric_na: |
| 100 | + return False |
| 101 | + metric_value = Decimal(metric_value) |
| 102 | + |
| 103 | + for match in float_value_re.findall(target_value): |
| 104 | + value = whitespace_re.sub("", match[0]) |
| 105 | + value = Decimal(value) |
| 106 | + |
| 107 | + if test_near(metric_value, value): |
| 108 | + return True |
| 109 | +        if test_near(metric_value.shift(2), value):  # reported value as a fraction, table cell as a percentage |
| 110 | +            return True |
| 111 | +        if test_near(metric_value, value.shift(2)):  # reported value as a percentage, table cell as a fraction |
| 112 | + return True |
| 113 | + |
| 114 | + return False |
| 115 | +# |
| 116 | +# if metric_value in metric_na or target_value in metric_na: |
| 117 | +# return False |
| 118 | +# if metric_value != target_value and metric_value in target_value: |
| 119 | +# print(f"|{metric_value}|{target_value}|") |
| 120 | +# return metric_value in target_value |
| 121 | + |
| 122 | + |
| 123 | +def match_metric(metric, tables, value): |
| 124 | + matching_tables = [] |
| 125 | + for table in tables: |
| 126 | + for col in tables[table]: |
| 127 | + for row in tables[table][col]: |
| 128 | + if fuzzy_match(metric, value, row): |
| 129 | + matching_tables.append(table) |
| 130 | + break |
| 131 | +            else: |
| 132 | +                continue  # no match in this column; try the next one |
| 133 | +            break  # a match was found; move on to the next table |
| 134 | + |
| 135 | + return matching_tables |
| 136 | + |
| 137 | + |
| 138 | +def label_tables(tasksfile, tables_dir): |
13 | 139 |     tasks = get_sota_tasks(tasksfile)
| 140 | + metadata, tables = get_tables(tables_dir) |
| 141 | + |
| 142 | +# for arxiv_id in tables: |
| 143 | +# for t in tables[arxiv_id]: |
| 144 | +# table = tables[arxiv_id][t] |
| 145 | +# for col in table: |
| 146 | +# for row in table[col]: |
| 147 | +# print(row) |
| 148 | +# return |
14 | 149 |     for task in tasks:
15 | 150 |         for dataset in task.datasets:
16 | 151 |             for row in dataset.sota.rows:
17 | | -                if 'arxiv.org' in row.paper_url: |
| 152 | + # TODO: some results have more than one url, CoRR + journal / conference |
| 153 | + # check if we have the same results for both |
| 154 | + |
| 155 | + match = arxiv_url_re.match(row.paper_url) |
| 156 | + if match is not None: |
| 157 | + arxiv_id = match.group("arxiv_id") |
| 158 | + if arxiv_id not in tables: |
| 159 | + print(f"No tables for {arxiv_id}. Skipping", file=sys.stderr) |
| 160 | + continue |
| 161 | + |
18 | 162 |                     for metric in row.metrics:
19 | | -                        print((task.name, dataset.name, metric, row.model_name, row.metrics[metric], row.paper_url)) |
20 | | - |
| 163 | + #print(f"{metric}\t{row.metrics[metric]}") |
| 164 | + #print((task.name, dataset.name, metric, row.model_name, row.metrics[metric], row.paper_url)) |
| 165 | + matching = match_metric(metric, tables[arxiv_id], row.metrics[metric]) |
| 166 | + #if not matching: |
| 167 | + # print(f"{metric}, {row.metrics[metric]}, {arxiv_id}") |
| 168 | + print(f"{metric},{len(matching)}") |
| 169 | + #if matching: |
| 170 | + # print((task.name, dataset.name, metric, row.model_name, row.metrics[metric], row.paper_url)) |
| 171 | + # print(matching) |
| 172 | + |
| 173 | + |
| 174 | + |
21 | 175 |
22 | 176 | if __name__ == "__main__": fire.Fire(label_tables)
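
Note on inputs: get_tables() above walks tables_dir for */metadata.json, treats each subdirectory name as an arXiv id, and appears to expect metadata.json to be a JSON list of objects, each with at least a "filename" key naming a CSV table in the same directory. A minimal sketch of such a layout, with a made-up arXiv id and file names (not part of this commit), usable as a test fixture:

import json
from pathlib import Path

fixture = Path("tables") / "1234.56789"            # hypothetical arXiv id directory
fixture.mkdir(parents=True, exist_ok=True)
(fixture / "metadata.json").write_text(json.dumps([{"filename": "table_01.csv"}]))
(fixture / "table_01.csv").write_text("Model,Accuracy\nOurs,77.4\n")  # read back with header=None, dtype=str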
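
The matching rule in test_near()/fuzzy_match() is: a reported metric value matches a table cell if the cell value, truncated or rounded to the reported value's number of decimal places, equals it, optionally after a two-digit shift to reconcile fractions with percentages. A standalone sketch of the same idea (it mirrors the code above rather than importing it):

from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation

def near(x, precise):
    # precise matches x if truncating or rounding it to x's precision gives x
    for rounding in (ROUND_DOWN, ROUND_HALF_UP):
        try:
            if x == precise.quantize(x, rounding=rounding):
                return True
        except InvalidOperation:
            pass
    return False

assert near(Decimal("77.4"), Decimal("77.44"))            # truncation
assert near(Decimal("77.4"), Decimal("77.36"))            # half-up rounding
assert near(Decimal("0.774").shift(2), Decimal("77.4"))   # fraction vs. percentage via shift(2)
assert not near(Decimal("77.4"), Decimal("78.1"))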
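
Usage: fire.Fire(label_tables) exposes the function's parameters as positional command-line arguments, so the script (label_tables.py is a placeholder name) would be run as python label_tables.py evaluation-tables.json tables/, with the tasks file first and the extracted-tables directory second. It then prints one metric,match_count line for each metric of every SOTA entry whose paper URL is an arXiv link with extracted tables, and reports skipped papers on stderr.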