
Commit 9bf1f13

Copy linking files from pwc repo
1 parent ff48090 commit 9bf1f13

File tree

4 files changed: +499 -0 lines changed

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
import re
from decimal import Decimal
from dataclasses import dataclass
import numpy as np
import pandas as pd
from elasticsearch import Elasticsearch, client
import logging


@dataclass()
class Value:
    type: str
    value: str

    def __str__(self):
        return self.value


@dataclass()
class Cell:
    cell_ext_id: str
    table_ext_id: str
    row: int
    col: int


@dataclass()
class Proposal:
    cell: Cell
    dataset_values: list
    table_description: str
    model_values: list  # values tagged model-best / model-paper / model-competing
    model_params: dict = None
    raw_value: str = ""

    def __post_init__(self):
        if self.model_params is None:
            self.model_params = {}

    @property
    def dataset(self):
        return ' '.join(map(str, self.dataset_values)).strip()

    @property
    def model_name(self):
        return ' '.join(map(str, self.model_values)).strip()

    @property
    def model_type(self):
        types = [v.type for v in self.model_values] + ['']
        if 'model-competing' in types:
            # a competing model is different from model-paper and model-best, so report it first
            return 'model-competing'
        return types[0]

    def __str__(self):
        return f"{self.model_name}: {self.raw_value} on {self.dataset}"


def mkquery_ngrams(query):
    return {
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["dataset^3", "dataset.ngrams^1", "metric^1", "metric.ngrams^1", "task^1",
                           "task.ngrams^1"]
            }
        }
    }


def mkquery_fullmatch(query):
    return {
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["dataset^3", "metric^1", "task^1"]
            }
        }
    }


class MatchSearch:
    def __init__(self, mkquery=mkquery_ngrams, es=None):
        self.case = True
        self.all_fields = True
        self.es = es or Elasticsearch()
        self.log = logging.getLogger(__name__)
        self.mkquery = mkquery

    def preproc(self, val):
        val = val.strip(',- ')
        val = re.sub("dataset", '', val, flags=re.I)
        # if self.case:
        #     val += (" " + re.sub("([a-z])([A-Z])", r'\1 \2', val)
        #             + " " + re.sub("([a-zA-Z])([0-9])", r'\1 \2', val)
        #             )
        return val

    def search(self, query, explain_doc_id=None):
        body = self.mkquery(query)
        if explain_doc_id is not None:
            return self.es.explain('et_taxonomy', doc_type='doc', id=explain_doc_id, body=body)
        return self.es.search('et_taxonomy', doc_type='doc', body=body)["hits"]

    def __call__(self, query):
        split_re = re.compile('([^a-zA-Z0-9])')
        query = self.preproc(query).strip()
        results = self.search(query)
        hits = results["hits"][:3]
        df = pd.DataFrame.from_records([
            dict(**hit["_source"],
                 # roughly normalize the score so that query length is not ignored
                 confidence=hit["_score"] / len(split_re.split(query)),
                 evidence=query) for hit in hits
        ], columns=["dataset", "metric", "task", "confidence", "evidence"])
        if not len(df):
            self.log.debug("Elastic query didn't produce any output: %r %r", query, hits)
        else:
            scores = []
            for dataset in df["dataset"]:
                r = self.search(dataset)
                scores.append(
                    dict(ok_score=r['hits'][0]['_score'] / len(split_re.split(dataset)),
                         bad_score=r['hits'][1]['_score'] / len(split_re.split(dataset))))

            scores = pd.DataFrame.from_records(scores)
            df['confidence'] = ((scores['ok_score'] - scores['bad_score']) / scores['bad_score']) * df['confidence'] / scores['ok_score']
        return df[["dataset", "metric", "task", "confidence", "evidence"]]


float_pm_re = re.compile(r"(±?)([+-]?\s*(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)\s*(%?)")
whitespace_re = re.compile(r"\s+")


def handle_pm(value):
    "handle percentage and ± values: yield each number that is not an error term"
    for match in float_pm_re.findall(value):
        if not match[0]:
            try:
                yield Decimal(whitespace_re.sub("", match[1])) / (100 if match[-1] else 1)
            except:
                pass
# %%


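# Illustrative note (not part of the original file): handle_pm yields every number in a
# cell that is not a ± error term, rescaling values marked with a percent sign, e.g.
#   list(handle_pm("85.3 ± 0.2"))  ->  [Decimal('85.3')]
#   list(handle_pm("85.3%"))       ->  [Decimal('0.853')]
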
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking):
    # %%
    # Proposal generation
    def consume_cells(matrix):
        for row_id, row in enumerate(matrix):
            for col_id, cell in enumerate(row):
                yield (row_id, col_id, cell)

    def annotations(r, c, type='model'):
        for nc in range(0, c):
            if type in structure[r, nc]:
                yield Value(structure[r, nc], matrix[r, nc])
        for nr in range(0, r):
            if type in structure[nr, c]:
                yield Value(structure[nr, c], matrix[nr, c])

    number_re = re.compile(r'^[± Ee /()^0-9.%±_-]{2,}$')

    proposals = [Proposal(
        cell=Cell(cell_ext_id=f"{table_ext_id}/{r}.{c}",
                  table_ext_id=table_ext_id,
                  row=r,
                  col=c
                  ),
        # TODO: add table type: sota / error ablation
        table_description=desc,
        model_values=list(annotations(r, c, 'model')),
        dataset_values=list(annotations(r, c, 'dataset')),
        raw_value=val)
        for r, c, val in consume_cells(matrix)
        if structure[r, c] == '' and number_re.match(matrix[r, c].strip())]

    def linked_proposals(proposals):
        for prop in proposals:
            if prop.dataset == '' or prop.model_type == '':
                continue
            if 'dev' in prop.dataset.lower() or 'train' in prop.dataset.lower():
                continue

            df = taxonomy_linking(prop.dataset)
            if not len(df):
                continue

            metric = df['metric'][0]

            # heuristic to handle accuracy vs. error
            first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
            format = "{x}"
            if first_num > 1:
                first_num /= 100
                format = "{x/100}"

            if ("error" in metric or "Error" in metric) and (first_num > 0.5):
                metric = "Accuracy"

            yield {
                'dataset': df['dataset'][0],
                'metric': metric,
                'task': df['task'][0],
                'format': format,
                'raw_value': prop.raw_value,
                'model': prop.model_name,
                'model_type': prop.model_type,
                'cell_ext_id': prop.cell.cell_ext_id,
                'confidence': df['confidence'][0],
            }

    return list(linked_proposals(proposals))


def linked_proposals(paper_ext_id, tables, structure_annotator, taxonomy_linking=MatchSearch()):
    proposals = []
    for idx, table in enumerate(tables):
        matrix = np.array(table.matrix)
        structure, tags = structure_annotator(table)
        structure = np.array(structure)
        desc = table.desc
        table_ext_id = f"{paper_ext_id}/{table.name}"

        if 'sota' in tags and 'no_sota_records' not in tags:  # only parse tables that are marked as sota
            proposals += list(generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking))
    return pd.DataFrame.from_records(proposals)


def test_link_taxonomy():
    link_taxonomy_raw = MatchSearch()
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("miniImageNet 5-way 1-shot"))
    # assert "Mini-ImageNet - 1-Shot Learning" == results["hits"][0]["_source"]["dataset"], results
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("CoNLL2003"))
    assert "CoNLL 2003 (English)" == results["hits"][0]["_source"]["dataset"], results
    results = link_taxonomy_raw.search(link_taxonomy_raw.preproc("AGNews"))
    assert "AG News" == results["hits"][0]["_source"]["dataset"], results
    link_taxonomy_raw("miniImageNet 5-way 1-shot")
    # %%
    split_re = re.compile('([^a-zA-Z0-9])')

    # %%
    q = "miniImageNet 5-way 1-shot Mini ImageNet 1-Shot Learning" * 1
    r = link_taxonomy_raw.search(q)
    f = len(split_re.split(q))
    r['hits'][0]['_score'] / f, r['hits'][1]['_score'] / f, r['hits'][0]['_source']
    # %%
    q = "Mini ImageNet 1-Shot Learning" * 1
    r = link_taxonomy_raw.search(q)
    f = len(split_re.split(q))
    r['hits'][0]['_score'] / f, r['hits'][1]['_score'] / f, r['hits'][0]['_source']
    # %%
    q = "Mini ImageNet 1-Shot" * 1
    r = link_taxonomy_raw.search(q)
    f = len(split_re.split(q))
    r['hits'][0]['_score'] / f, r['hits'][1]['_score'] / f, r['hits'][0]['_source']

#
# # %%
# prop = proposals[1]
# print(prop)
# # todo: issue with STS-B matching IJB-B
# link_taxonomy_raw(prop.dataset)

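For orientation, a brief hedged sketch of how the matcher above could be exercised on its own; it assumes a local Elasticsearch instance whose et_taxonomy index is already populated, and the query string is only an example:

linker = MatchSearch()
df = linker("CIFAR-10")  # DataFrame with dataset, metric, task, confidence, evidence columns
if len(df):
    top = df.iloc[0]
    print(top["dataset"], top["metric"], top["task"], top["confidence"])
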
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import pandas as pd
from django.db import connection

from sota import models
from sota.pipeline.format import extract_value
from sota.pipeline.metrics import Metrics


def q(query, limit=10, index_col=None):
    if limit is not None:
        query = query.rstrip(" ;") + f" LIMIT {limit}"
    return pd.read_sql(query, connection, index_col=index_col)


def execute_model_on_papers(model, papers):
    papers = models.Paper.objects.filter(pk__in=papers)
    proposals = []
    for paper in papers:
        print("Parsing", paper.id)
        paper_proposals = model(paper.id, paper.table_set.all())
        proposals.append(paper_proposals)
    proposals = pd.concat(proposals)
    proposals["parsed"] = proposals[["raw_value", "format"]].apply(
        lambda row: float(extract_value(row.raw_value, row.format)), axis=1)
    proposals["experiment_name"] = model.__name__
    return proposals.set_index('cell_ext_id')


def fetch_gold_sota_records():
    gold_sota_records = q("""
        SELECT sc.id as cell_id,
               st.paper_id,
               CONCAT(st.paper_id, '/', st.name, '/', sr.row, '.', sr.col) as cell_ext_id,
               (SELECT gold_tags FROM sota_cell
                WHERE (row=sc.row or col=sc.col) and table_id=sc.table_id
                  and gold_tags LIKE 'model%' LIMIT 1) as model_type,
               task, dataset, metric, model, format, sc.value as raw_value
        FROM sota_record sr
        JOIN sota_cell sc USING (table_id, row, col)
        JOIN sota_table st ON (sc.table_id=st.id)
        WHERE dataset != '' and task != '' and metric != '' and model != '';""", limit=None)
    gold_sota_records["parsed"] = gold_sota_records[["raw_value", "format"]].apply(
        lambda row: float(extract_value(row.raw_value, row.format)), axis=1)

    # keep only records whose raw value could be parsed (NaN != NaN)
    gold_sota_records = gold_sota_records[gold_sota_records["parsed"] == gold_sota_records["parsed"]]

    strip_cols = ["task", "dataset", "format", "metric", "raw_value", "model", "model_type"]
    gold_sota_records = gold_sota_records.transform(
        lambda x: x.str.strip() if x.name in strip_cols else x)
    gold_sota_records = gold_sota_records.set_index('cell_ext_id')
    return gold_sota_records


def fetch_gold_sota_papers():
    return q("""
        SELECT st.paper_id
        FROM sota_record sr
        JOIN sota_cell sc USING (table_id, row, col)
        JOIN sota_table st ON (sc.table_id=st.id)
        WHERE dataset != '' and task != '' and metric != '' and model != ''
        GROUP BY st.paper_id;""", limit=None)["paper_id"].tolist()


class Evaluator:
    def __init__(self, model):
        self.model = model
        self.annotated_papers = fetch_gold_sota_papers()
        self.raw_proposals = None

    def run_model(self):
        self.raw_proposals = execute_model_on_papers(model=self.model, papers=self.annotated_papers)

    def evaluate(self, confidence=-1):
        if self.raw_proposals is None:
            self.run_model()
        proposals = self.raw_proposals[self.raw_proposals['confidence'] > confidence]
        gold_sota_records = fetch_gold_sota_records()
        df = gold_sota_records.merge(proposals, 'outer', left_index=True, right_index=True,
                                     suffixes=['_gold', '_pred'])
        df = df.reindex(sorted(df.columns), axis=1)
        df = df.fillna('not-present')
        if "experiment_name" in df.columns:
            del df["experiment_name"]

        return Metrics(df, experiment_name=self.model.__name__)
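
A minimal driver sketch; it assumes a configured Django environment for the sota app, that the linking module above is importable (the sota.pipeline.linking path is an assumption), and uses a hypothetical dummy_annotator as a stand-in for the real structure predictor:

from sota.pipeline.linking import linked_proposals  # assumed module path

def dummy_annotator(table):
    # hypothetical stand-in: a real annotator predicts per-cell tags and table-level tags
    structure = [['' for _ in row] for row in table.matrix]
    return structure, ['sota']

def my_model(paper_ext_id, tables):
    return linked_proposals(paper_ext_id, tables, structure_annotator=dummy_annotator)

evaluator = Evaluator(model=my_model)
metrics = evaluator.evaluate(confidence=0.1)  # keep proposals with confidence above 0.1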
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
import re
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation

float_value_re = re.compile(r"([+-]?(?:(?:\d{1,2}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
float_value_nc = re.compile(r"(?:[+-]?(?:(?:\d{1,2}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")
par_re = re.compile(r"\{([^\}]*)\}")
escaped_whitespace_re = re.compile(r"(\\\s)+")


def format_to_regexp(format):
    # build a regexp from a format template in which "{...}" marks the numeric value
    # (e.g. "{x}", "{x/100}", "{100*x}") and return it together with a rescaling function
    placeholders = par_re.split(format.strip())
    regexp = ""
    fn = lambda x: x
    for i, s in enumerate(placeholders):
        if i % 2 == 0:
            regexp += escaped_whitespace_re.sub(r"\\s+", re.escape(s))
        elif s.strip() == "":
            regexp += float_value_nc.pattern
        else:
            regexp += float_value_re.pattern
            ss = s.strip()
            if ss == "100*x" or ss == "100x":
                fn = lambda x: 100 * x
            elif ss == "x/100":
                fn = lambda x: x / 100
    return re.compile('^' + regexp + '$'), fn


def extract_value(cell_value, format):
    regexp, fn = format_to_regexp(format)
    match = regexp.match(cell_value)
    if match is None:
        return Decimal('NaN')
    return fn(Decimal(match.group(1)))
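
A short hedged example of what extract_value returns for the format templates this helper understands (values are illustrative only):

extract_value("85.3", "{x}")       # -> Decimal('85.3')
extract_value("85.3", "{x/100}")   # -> Decimal('0.853'): the cell reports a percentage
extract_value("0.853", "{100*x}")  # -> Decimal('85.300'): the cell is rescaled to a percentage
extract_value("n/a", "{x}")        # -> Decimal('NaN') when the cell does not match the format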

0 commit comments