import logging
import re
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation

import numpy as np
import pandas as pd
from elasticsearch import Elasticsearch


@dataclass
class Value:
    type: str
    value: str

    def __str__(self):
        return self.value


@dataclass
class Cell:
    cell_ext_id: str
    table_ext_id: str
    row: int
    col: int


@dataclass
class Proposal:
    cell: Cell
    dataset_values: list
    table_description: str
    model_values: list  # values tagged model-best, model-paper or model-competing
    model_params: dict = field(default_factory=dict)
    raw_value: str = ""

    @property
    def dataset(self):
        return ' '.join(map(str, self.dataset_values)).strip()

    @property
    def model_name(self):
        return ' '.join(map(str, self.model_values)).strip()

    @property
    def model_type(self):
        types = [v.type for v in self.model_values] + ['']
        if 'model-competing' in types:
            # 'model-competing' takes precedence over 'model-paper' and 'model-best'
            return 'model-competing'
        return types[0]

    def __str__(self):
        return f"{self.model_name}: {self.raw_value} on {self.dataset}"
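
# Example (hypothetical values): a Proposal whose model_values contain
# Value('model-best', 'Our LSTM'), whose dataset_values contain
# Value('dataset', 'CoNLL2003') and whose raw_value is '92.3' has
# model_type == 'model-best' and str(prop) == 'Our LSTM: 92.3 on CoNLL2003'.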


def mkquery_ngrams(query):
    return {
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["dataset^3", "dataset.ngrams^1",
                           "metric^1", "metric.ngrams^1",
                           "task^1", "task.ngrams^1"]
            }
        }
    }


def mkquery_fullmatch(query):
    return {
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["dataset^3", "metric^1", "task^1"]
            }
        }
    }
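
# Both builders assume the 'et_taxonomy' index maps 'dataset', 'metric' and
# 'task' fields (the ngram variant additionally assumes '.ngrams' subfields);
# '^3' boosts dataset matches over metric and task matches.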


class MatchSearch:
    def __init__(self, mkquery=mkquery_ngrams, es=None):
        self.case = True
        self.all_fields = True
        self.es = es or Elasticsearch()
        self.log = logging.getLogger(__name__)
        self.mkquery = mkquery

    def preproc(self, val):
        # drop the word "dataset" and surrounding punctuation from the query
        val = val.strip(',- ')
        val = re.sub("dataset", '', val, flags=re.I)
        # if self.case:
        #     val += (" " + re.sub("([a-z])([A-Z])", r'\1 \2', val)
        #             + " " + re.sub("([a-zA-Z])([0-9])", r'\1 \2', val)
        #             )
        return val

    def search(self, query, explain_doc_id=None):
        body = self.mkquery(query)
        if explain_doc_id is not None:
            return self.es.explain(index='et_taxonomy', doc_type='doc',
                                   id=explain_doc_id, body=body)
        return self.es.search(index='et_taxonomy', doc_type='doc', body=body)["hits"]

    def __call__(self, query):
        split_re = re.compile('([^a-zA-Z0-9])')
        query = self.preproc(query).strip()
        results = self.search(query)
        hits = results["hits"][:3]
        df = pd.DataFrame.from_records([
            dict(**hit["_source"],
                 # roughly normalise the score so that longer queries are not favoured
                 confidence=hit["_score"] / len(split_re.split(query)),
                 evidence=query) for hit in hits
        ], columns=["dataset", "metric", "task", "confidence", "evidence"])
        if not len(df):
            self.log.debug("Elastic query %r didn't produce any output (hits: %r)", query, hits)
        else:
            scores = []
            for dataset in df["dataset"]:
                r = self.search(dataset)
                scores.append(
                    dict(ok_score=r['hits'][0]['_score'] / len(split_re.split(dataset)),
                         bad_score=r['hits'][1]['_score'] / len(split_re.split(dataset))))

            scores = pd.DataFrame.from_records(scores)
            # rescale confidence by how much the top hit stands out from the runner-up
            df['confidence'] = ((scores['ok_score'] - scores['bad_score'])
                                / scores['bad_score']) * df['confidence'] / scores['ok_score']
        return df[["dataset", "metric", "task", "confidence", "evidence"]]
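
# Usage sketch (assumes a reachable Elasticsearch instance with a populated
# 'et_taxonomy' index; the index name and mapping are this project's setup,
# not part of elasticsearch-py):
#     search = MatchSearch()
#     df = search("CoNLL 2003")  # columns: dataset, metric, task, confidence, evidence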


float_pm_re = re.compile(r"(±?)([+-]?\s*(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)\s*(%?)")
whitespace_re = re.compile(r"\s+")


def handle_pm(value):
    "Yield the numbers found in a cell value, skipping ± error terms and scaling percentages."
    for match in float_pm_re.findall(value):
        if not match[0]:  # skip the "± x" error term
            try:
                yield Decimal(whitespace_re.sub("", match[1])) / (100 if match[-1] else 1)
            except InvalidOperation:
                pass
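
# Example: list(handle_pm("85.4% ± 0.2")) == [Decimal('0.854')]
# (the '± 0.2' term is skipped and '%' divides the value by 100).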


def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking):
    # Proposal generation
    def consume_cells(matrix):
        for row_id, row in enumerate(matrix):
            for col_id, cell in enumerate(row):
                yield (row_id, col_id, cell)

    def annotations(r, c, type='model'):
        # collect annotated cells of the given type from the row and column headers
        for nc in range(0, c):
            if type in structure[r, nc]:
                yield Value(structure[r, nc], matrix[r, nc])
        for nr in range(0, r):
            if type in structure[nr, c]:
                yield Value(structure[nr, c], matrix[nr, c])

    # matches cells containing only digits, punctuation and ± / % / exponent symbols
    number_re = re.compile(r'^[± Ee /()^0-9.%±_-]{2,}$')

    proposals = [Proposal(
        cell=Cell(cell_ext_id=f"{table_ext_id}/{r}.{c}",
                  table_ext_id=table_ext_id,
                  row=r,
                  col=c),
        # TODO: add table type: sota / error ablation
        table_description=desc,
        model_values=list(annotations(r, c, 'model')),
        dataset_values=list(annotations(r, c, 'dataset')),
        raw_value=val)
        for r, c, val in consume_cells(matrix)
        if structure[r, c] == '' and number_re.match(matrix[r, c].strip())]

    def linked_proposals(proposals):
        for prop in proposals:
            if prop.dataset == '' or prop.model_type == '':
                continue
            if 'dev' in prop.dataset.lower() or 'train' in prop.dataset.lower():
                continue

            df = taxonomy_linking(prop.dataset)
            if not len(df):
                continue

            metric = df['metric'][0]

            # heuristic to handle accuracy vs error
            first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
            fmt = "{x}"
            if first_num > 1:
                first_num /= 100
                fmt = "{x/100}"

            if "error" in metric.lower() and first_num > 0.5:
                metric = "Accuracy"

            yield {
                'dataset': df['dataset'][0],
                'metric': metric,
                'task': df['task'][0],
                'format': fmt,
                'raw_value': prop.raw_value,
                'model': prop.model_name,
                'model_type': prop.model_type,
                'cell_ext_id': prop.cell.cell_ext_id,
                'confidence': df['confidence'][0],
            }

    return list(linked_proposals(proposals))
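
# Minimal sketch of the expected inputs (toy values, not from a real paper):
#     matrix = np.array([['Model', 'CoNLL2003'],
#                        ['Our LSTM', '92.3']])
#     structure = np.array([['', 'dataset'],
#                           ['model-best', '']])
#     generate_proposals_for_table('1234.56789/table_01.csv', matrix, structure,
#                                  'NER results', MatchSearch())
# The unannotated numeric cell '92.3' becomes a proposal whose model and
# dataset are taken from its row and column headers.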


def linked_proposals(paper_ext_id, tables, structure_annotator, taxonomy_linking=None):
    # construct the default lazily so importing this module doesn't open an ES connection
    taxonomy_linking = taxonomy_linking or MatchSearch()
    proposals = []
    for table in tables:
        matrix = np.array(table.matrix)
        structure, tags = structure_annotator(table)
        structure = np.array(structure)
        desc = table.desc
        table_ext_id = f"{paper_ext_id}/{table.name}"

        if 'sota' in tags and 'no_sota_records' not in tags:  # only parse tables marked as sota
            proposals += list(generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking))
    return pd.DataFrame.from_records(proposals)
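
# Usage sketch ('tables' is an iterable of objects exposing .matrix, .desc and
# .name, and 'structure_annotator' returns a (structure, tags) pair; the names
# here are illustrative):
#     df = linked_proposals('1234.56789', tables, annotator)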


def test_link_taxonomy():
    link_taxonomy_raw = MatchSearch()
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("miniImageNet 5-way 1-shot"))
    # assert "Mini-ImageNet - 1-Shot Learning" == results["hits"][0]["_source"]["dataset"], results
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("CoNLL2003"))
    assert "CoNLL 2003 (English)" == results["hits"][0]["_source"]["dataset"], results
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("AGNews"))
    assert "AG News" == results["hits"][0]["_source"]["dataset"], results
    link_taxonomy_raw("miniImageNet 5-way 1-shot")

    # inspect the normalised scores of the two top hits for a few query variants
    split_re = re.compile('([^a-zA-Z0-9])')
    for q in ["miniImageNet 5-way 1-shot Mini ImageNet 1-Shot Learning",
              "Mini ImageNet 1-Shot Learning",
              "Mini ImageNet 1-Shot"]:
        r = link_taxonomy_raw.search(q)
        f = len(split_re.split(q))
        print(r['hits'][0]['_score'] / f, r['hits'][1]['_score'] / f, r['hits'][0]['_source'])

    # todo: issue with STS-B matching IJB-B
    # prop = proposals[1]
    # print(prop)
    # link_taxonomy_raw(prop.dataset)