Skip to content

Commit c04a03c

Browse files
committed
Faster fuzzy matchers and changes to linking
1 parent 9bf1f13 commit c04a03c

File tree

11 files changed

+275
-94
lines changed

11 files changed

+275
-94
lines changed

label_tables.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import sys
1010
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
1111
from collections import Counter, namedtuple
12-
12+
from joblib import delayed, Parallel
13+
from sota_extractor2.data.paper_collection import PaperCollection, remove_arxiv_version
14+
from functools import reduce
1315

1416
arxiv_url_re = re.compile(r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
1517

@@ -33,6 +35,8 @@ def get_table(filename):
3335
return pd.DataFrame()
3436

3537

38+
# all_metadata[arxiv_id] = {'table_01.csv': 'Table 1: ...', ...}
39+
# all_tables[arxiv_id] = {'table_01.csv': DataFrame(...), ...}
3640
def get_tables(tables_dir):
3741
tables_dir = Path(tables_dir)
3842
all_metadata = {}
@@ -223,19 +227,20 @@ def mark_strings(table, tags, values):
223227
if match_str(real, s):
224228
cell_tags += f"{beg}{s}{end}"
225229
return cell_tags
226-
230+
227231

228232
metatables = {}
229-
def match_many(output_dir, task_name, dataset_name, metric_name, tables, values):
233+
def match_many(task_name, dataset_name, metric_name, tables, values):
234+
metatables = {}
230235
for arxiv_id in tables:
231236
for table in tables[arxiv_id]:
232237
tags = mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
233-
global metatables
234238
key = (arxiv_id, table)
235239
if key in metatables:
236240
metatables[key] += tags
237241
else:
238242
metatables[key] = tags
243+
return metatables
239244

240245

241246
def normalize_metric(value):
@@ -256,6 +261,26 @@ def normalize_table(table):
256261
return table.applymap(normalize_cell)
257262

258263

264+
celltags_re = re.compile(r"<hit><sota>(?P<sota>.*?)</sota><paper>(?P<paper>.*?)</paper><model>(?P<model>.*?)</model><metric>(?P<metric>.*?)</metric><dataset>(?P<dataset>.*?)</dataset><task>(?P<task>.*?)</task>(?P<this_paper><this_paper/>)?<comparator>(?P<comparator>.*?)</comparator><matched_cell>(?P<matched_cell>.*?)</matched_cell><matched_str>(?P<matched_str>.*?)</matched_str></hit>")
265+
def parse_celltags(v):
266+
r = []
267+
for m in celltags_re.finditer(v):
268+
d = m.groupdict()
269+
d['this_paper'] = d['this_paper'] is not None
270+
r.append(d)
271+
return r
272+
273+
274+
def celltags_to_json(df):
275+
tags = []
276+
for r, row in df.iterrows():
277+
for c, cell in enumerate(row):
278+
if cell != "":
279+
tags.append(dict(row=r, col=c, hits=parse_celltags(cell)))
280+
return tags
281+
282+
283+
259284
# for each task with sota row
260285
# arxivs <- list of papers related to the task
261286
# for each (dataset_name, metric_name) of the task:
@@ -269,40 +294,62 @@ def normalize_table(table):
269294
# if table.arxiv_id == paper_id: mark with this-tag
270295
PaperResult = namedtuple("PaperResult", ["arxiv_id", "model", "value", "normalized"])
271296

297+
arxivs_by_metrics = {}
298+
tables = {}
299+
300+
def match_for(task, dataset, metric):
301+
records = arxivs_by_metrics[(task, dataset, metric)]
302+
tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
303+
return match_many(task, dataset, metric, tabs, records)
304+
272305

273-
def label_tables(tasksfile, tables_dir):
274-
output_dir = Path(tables_dir)
306+
def label_tables(tasksfile, papers_dir, output, jobs=-1):
307+
print("Reading PwC entries...", file=sys.stderr)
275308
tasks = get_sota_tasks(tasksfile)
276-
metadata, tables = get_tables(tables_dir)
309+
print("Reading tables from files...", file=sys.stderr)
310+
pc = PaperCollection.from_files(papers_dir, load_texts=False, load_annotations=False, jobs=jobs)
277311

278-
arxivs_by_metrics = {}
312+
# share data between processes to avoid costly joblib serialization
313+
global arxivs_by_metrics, tables
279314

280-
tables = {arxiv_id: {tab: normalize_table(tables[arxiv_id][tab]) for tab in tables[arxiv_id]} for arxiv_id in tables}
315+
print("Normalizing tables...", file=sys.stderr)
316+
tables = {p.arxiv_no_version: {tab.name: normalize_table(tab.matrix) for tab in p.tables} for p in pc}
281317

318+
print("Aggregating papers...", file=sys.stderr)
282319
for task in tasks:
283320
for dataset in task.datasets:
284321
for row in dataset.sota.rows:
285322
match = arxiv_url_re.match(row.paper_url)
286323
if match is not None:
287-
arxiv_id = match.group("arxiv_id")
324+
arxiv_id = remove_arxiv_version(match.group("arxiv_id"))
288325
for metric in row.metrics:
289326
arxivs_by_metrics.setdefault((task.name, dataset.name, metric), set()).add(
290327
PaperResult(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
291328
normalized=normalize_metric(row.metrics[metric])
292329
)
293330
)
294331

295-
for task, dataset, metric in arxivs_by_metrics:
296-
records = arxivs_by_metrics[(task, dataset, metric)]
297-
tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
298-
match_many(output_dir, task, dataset, metric, tabs, records)
299-
300-
global metatables
301-
302-
for (arxiv_id, table), best in metatables.items():
303-
out = output_dir / arxiv_id
304-
out.mkdir(parents=True, exist_ok=True)
305-
best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
332+
print("Matching results...", file=sys.stderr)
333+
metatables_list = Parallel(n_jobs=jobs, backend="multiprocessing")(
334+
[delayed(match_for)(task, dataset, metric)
335+
for task, dataset, metric in arxivs_by_metrics])
336+
337+
print("Aggregating results...", file=sys.stderr)
338+
metatables = {}
339+
for mt in metatables_list:
340+
for k, v in mt.items():
341+
metatables[k] = metatables.get(k, "") + v
342+
grouped_metatables = {}
343+
for (arxiv_id, tablename), df in metatables.items():
344+
grouped_metatables.setdefault(arxiv_id, {})[tablename] = celltags_to_json(df)
345+
346+
with open(output, 'wt') as f:
347+
json.dump(grouped_metatables, f)
348+
# print("Saving matches...", file=sys.stderr)
349+
# for (arxiv_id, table), best in metatables.items():
350+
# out = output_dir / arxiv_id
351+
# out.mkdir(parents=True, exist_ok=True)
352+
# best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
306353

307354

308355
if __name__ == "__main__": fire.Fire(label_tables)

sota_extractor2/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
# otherwise use these files
1212
data = Path("/mnt/efs/pwc/data")
13-
goldtags_dump = data / "dumps" / "goldtags-2019.08.06_0835.json.gz"
13+
goldtags_dump = data / "dumps" / "goldtags-2019.09.13_0219.json.gz"
1414

1515

1616
elastic = dict(hosts=['localhost'], timeout=20)

sota_extractor2/data/paper_collection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, data=None):
5454
super().__init__(data)
5555

5656
@classmethod
57-
def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=True, jobs=-1, migrate=False):
57+
def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=True, load_annotations=True, jobs=-1, migrate=False):
5858
path = Path(path)
5959
if annotations_path is None:
6060
annotations_path = path / "structure-annotations.json"
@@ -63,12 +63,13 @@ def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=Tr
6363
else:
6464
texts = {}
6565

66-
annotations = _load_annotated_papers(annotations_path)
66+
annotations = {}
6767
if load_tables:
68+
if load_annotations:
69+
annotations = _load_annotated_papers(annotations_path)
6870
tables = _load_tables(path, annotations, jobs, migrate)
6971
else:
7072
tables = {}
71-
annotations = {}
7273
outer_join = set(texts).union(set(tables))
7374

7475
papers = [Paper(k, texts.get(k), tables.get(k, []), annotations.get(k)) for k in outer_join]

sota_extractor2/data/structure.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def create_evidence_records(textfrag, cell, table):
101101
"cell_type": cell.vals[1],
102102
"cell_content": fix_refs(cell.vals[0]),
103103
"cell_reference": cell.vals[2],
104+
"cell_layout": cell.vals[3],
105+
"cell_styles": cell.vals[4],
104106
"this_paper": textfrag.paper_id == table.paper_id,
105107
"row": cell.row,
106108
"col": cell.col,
@@ -125,7 +127,7 @@ def get_limits(cell_type):
125127
return dict(paper_limit=paper_limit, corpus_limit=corpus_limit)
126128
records = [
127129
record
128-
for cell in consume_cells(table.matrix, table.matrix_gold_tags, table.matrix_references) if filter_cells(cell)
130+
for cell in consume_cells(table.matrix, table.matrix_gold_tags, table.matrix_references, table.matrix_layout, table.matrix_styles) if filter_cells(cell)
129131
for evidence in fetch_evidence(cell.vals[0], cell.vals[2], paper_id=table.paper_id, **get_limits(cell.vals[1]))
130132
for record in create_evidence_records(evidence, cell, table=table)
131133
]

sota_extractor2/data/table.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ def __init__(self, name, df, layout, caption=None, figure_id=None, annotations=N
112112
self.dataset_text = ''
113113
self.notes = ''
114114

115+
@property
116+
def matrix(self):
117+
return self.df.applymap(lambda x: x.value)
118+
119+
@property
120+
def matrix_gold_tags(self):
121+
return self.df.applymap(lambda x: x.gold_tags)
122+
115123
@classmethod
116124
def from_file(cls, path, metadata, annotations=None, migrate=False, match_name=None, guessed_tags=None):
117125
path = Path(path)

sota_extractor2/models/linking/bm25_naive.py

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
import pandas as pd
66
from elasticsearch import Elasticsearch, client
77
import logging
8+
#from .extractors import DatasetExtractor
9+
import spacy
10+
from scispacy.abbreviation import AbbreviationDetector
11+
from sota_extractor2.models.linking.format import extract_value
812

913

1014
@dataclass()
@@ -84,9 +88,36 @@ def __init__(self, mkquery=mkquery_ngrams, es=None):
8488
self.log = logging.getLogger(__name__)
8589
self.mkquery = mkquery
8690

87-
def preproc(self, val):
91+
self.nlp = spacy.load("en_core_web_sm")
92+
abbreviation_pipe = AbbreviationDetector(self.nlp)
93+
self.nlp.add_pipe(abbreviation_pipe)
94+
self.nlp.disable_pipes("tagger", "ner", "parser")
95+
96+
def match_abrv(self, dataset, datasets):
97+
abrvs = []
98+
for ds in datasets:
99+
# "!" is a workaround to scispacy error
100+
doc = self.nlp(f"! {ds} ({dataset})")
101+
for abrv in doc._.abbreviations:
102+
if str(abrv) == dataset and str(abrv._.long_form) == ds:
103+
abrvs.append(str(abrv._.long_form))
104+
abrvs = list(set(abrvs))
105+
if len(abrvs) == 1:
106+
print(f"abrv. for {dataset}: {abrvs[0]}")
107+
return abrvs[0]
108+
elif len(abrvs) == 0:
109+
return None
110+
else:
111+
print(f"Multiple abrvs. for {dataset}: {abrvs}")
112+
return None
113+
114+
def preproc(self, val, datasets=None):
88115
val = val.strip(',- ')
89116
val = re.sub("dataset", '', val, flags=re.I)
117+
if datasets:
118+
abrv = self.match_abrv(val, datasets)
119+
if abrv:
120+
val += " " + abrv
90121
# if self.case:
91122
# val += (" " +re.sub("([a-z])([A-Z])", r'\1 \2', val)
92123
# +" " +re.sub("([a-zA-Z])([0-9])", r'\1 \2', val)
@@ -99,9 +130,11 @@ def search(self, query, explain_doc_id=None):
99130
return self.es.explain('et_taxonomy', doc_type='doc', id=explain_doc_id, body=body)
100131
return self.es.search('et_taxonomy', doc_type='doc', body=body)["hits"]
101132

102-
def __call__(self, query):
133+
def __call__(self, query, datasets, caption):
103134
split_re = re.compile('([^a-zA-Z0-9])')
104-
query = self.preproc(query).strip()
135+
query = self.preproc(query, datasets).strip()
136+
if caption:
137+
query += " " + self.preproc(caption).strip()[:400]
105138
results = self.search(query)
106139
hits = results["hits"][:3]
107140
df = pd.DataFrame.from_records([
@@ -136,7 +169,7 @@ def handle_pm(value):
136169
pass
137170
# %%
138171

139-
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking):
172+
def generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets):
140173
# %%
141174
# Proposal generation
142175
def consume_cells(matrix):
@@ -170,30 +203,37 @@ def annotations(r, c, type='model'):
170203
for r, c, val in consume_cells(matrix)
171204
if structure[r, c] == '' and number_re.match(matrix[r, c].strip())]
172205

206+
# def empty_proposal(cell_ext_id, reason):
207+
# np = "not-present"
208+
# return dict(
209+
# dataset=np, metric=np, task=np, format=np, raw_value=np, model=np,
210+
# model_type=np, cell_ext_id=cell_ext_id, confidence=-1, debug_reason=reason
211+
# )
212+
173213
def linked_proposals(proposals):
174214
for prop in proposals:
175-
if prop.dataset == '' or prop.model_type == '':
176-
continue
177-
if 'dev' in prop.dataset.lower() or 'train' in prop.dataset.lower():
178-
continue
179-
180-
df = taxonomy_linking(prop.dataset)
181-
if not len(df):
182-
continue
215+
df = taxonomy_linking(prop.dataset, datasets, desc, debug_info=prop)
216+
assert len(df) == 1
183217

184218
metric = df['metric'][0]
185219

186220
# heuristic to handle accuracy vs error
187221
first_num = (list(handle_pm(prop.raw_value)) + [0])[0]
188222
format = "{x}"
189-
if first_num > 1:
190-
first_num /= 100
191-
format = "{x/100}"
192-
193-
if ("error" in metric or "Error" in metric) and (first_num > 0.5):
223+
# if first_num > 1:
224+
# first_num /= 100
225+
# format = "{x/100}"
226+
if first_num < 1 and '%' not in prop.raw_value:
227+
first_num *= 100
228+
format = "{100*x}"
229+
if '%' in prop.raw_value:
230+
format += '%'
231+
232+
# if ("error" in metric or "Error" in metric) and (first_num > 0.5):
233+
if (metric.strip().lower() == "error") and (first_num > 0.5):
194234
metric = "Accuracy"
195235

196-
yield {
236+
linked = {
197237
'dataset': df['dataset'][0],
198238
'metric': metric,
199239
'task': df['task'][0],
@@ -203,22 +243,38 @@ def linked_proposals(proposals):
203243
'model_type': prop.model_type,
204244
'cell_ext_id': prop.cell.cell_ext_id,
205245
'confidence': df['confidence'][0],
246+
'struct_model_type': prop.model_type,
247+
'struct_dataset': prop.dataset
206248
}
249+
yield linked
250+
251+
# specify columns in case there's no proposal
252+
columns = ['dataset', 'metric', 'task', 'format', 'raw_value', 'model', 'model_type', 'cell_ext_id', 'confidence', 'parsed',
253+
'struct_model_type', 'struct_dataset']
254+
proposals = pd.DataFrame.from_records(list(linked_proposals(proposals)), columns=columns)
255+
256+
if len(proposals):
257+
proposals["parsed"]=proposals[["raw_value", "format"]].apply(
258+
lambda row: float(extract_value(row.raw_value, row.format)), axis=1)
259+
return proposals
207260

208-
return list(linked_proposals(proposals))
209261

210-
def linked_proposals(paper_ext_id, tables, structure_annotator, taxonomy_linking=MatchSearch()):
262+
def linked_proposals(paper_ext_id, paper, tables, structure_annotator, taxonomy_linking=MatchSearch(),
263+
dataset_extractor=None):
264+
# dataset_extractor=DatasetExtractor()):
211265
proposals = []
266+
datasets = dataset_extractor.from_paper(paper)
267+
print(f"Extracted datasets: {datasets}")
212268
for idx, table in enumerate(tables):
213269
matrix = np.array(table.matrix)
214-
structure, tags = structure_annotator(table)
270+
structure, tags = structure_annotator(paper, table)
215271
structure = np.array(structure)
216-
desc = table.desc
272+
desc = table.caption
217273
table_ext_id = f"{paper_ext_id}/{table.name}"
218274

219275
if 'sota' in tags and 'no_sota_records' not in tags: # only parse tables that are marked as sota
220-
proposals += list(generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking))
221-
return pd.DataFrame.from_records(proposals)
276+
proposals.append(generate_proposals_for_table(table_ext_id, matrix, structure, desc, taxonomy_linking, datasets))
277+
return pd.concat(proposals)
222278

223279

224280
def test_link_taxonomy():

0 commit comments

Comments
 (0)