
Commit 184224c

Merge pull request #2 from paperswithcode/pipeline
Pipeline
2 parents 8385aaf + 5026084 · commit 184224c

40 files changed: +2823 −126 lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ dependencies:
 - elasticsearch-dsl=7.0.0
 - ipython=7.5.0
 - joblib=0.13.2
+- python-magic=0.4.15

extract_tables.py

Lines changed: 13 additions & 16 deletions
@@ -13,8 +13,7 @@
 from dataclasses import dataclass
 from typing import Set
 
-from tabular import Tabular
-
+from sota_extractor2.data.table import Table
 
 # begin of dirty hack
 # pandas parsing of html tables is really nice
@@ -265,18 +264,13 @@ def html2data(table):
     return data[0] if len(data) == 1 else None
 
 
-def save_table(data, filename):
-    data.to_csv(filename, header=None, index=None)
-
-
 def save_tables(data, outdir):
     metadata = []
 
     for num, table in enumerate(data, 1):
         filename = f"table_{num:02}.csv"
         layout = f"layout_{num:02}.csv"
-        save_table(table.data, outdir / filename)
-        save_table(table.layout, outdir / layout)
+        table.save(outdir, filename, layout)
         metadata.append(dict(filename=filename, layout=layout, caption=table.caption, figure_id=table.figure_id))
     with open(outdir / "metadata.json", "w") as f:
         json.dump(metadata, f)
@@ -341,11 +335,7 @@ def remove_footnotes(soup):
         elem.extract()
 
 
-def extract_tables(filename, outdir):
-    with open(filename, "rb") as f:
-        html = f.read()
-    outdir = Path(outdir)
-    outdir.mkdir(parents=True, exist_ok=True)
+def extract_tables(html):
     soup = BeautifulSoup(html, "lxml", from_encoding="utf-8")
     set_ids_by_labels(soup)
     fix_span_tables(soup)
@@ -381,8 +371,15 @@ def extract_tables(filename, outdir):
         if cap_el is not None:
             caption = clear_ws(cap_el.get_text())
         figure_id = table.get("data-figure-id")
-        data.append(Tabular(tab, layout, caption, figure_id))
+        data.append(Table(f"table_{len(data)+1:02}", tab, layout.applymap(str), caption, figure_id))
+    return data
 
-    save_tables(data, outdir)
+def extract_tables_cmd(filename, outdir):
+    with open(filename, "rb") as f:
+        html = f.read()
+    tables = extract_tables(html)
+    outdir = Path(outdir)
+    outdir.mkdir(parents=True, exist_ok=True)
+    save_tables(tables, outdir)
 
-if __name__ == "__main__": fire.Fire(extract_tables)
+if __name__ == "__main__": fire.Fire(extract_tables_cmd)
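
The refactor splits extraction into a pure function, extract_tables(html), which returns Table objects, and a thin CLI wrapper, extract_tables_cmd, that handles file I/O. A minimal usage sketch, assuming a locally saved HTML render (the file and directory names are placeholders):

# programmatic use of the new API
with open("paper.html", "rb") as f:      # placeholder input file
    html = f.read()
tables = extract_tables(html)            # list of Table objects
for table in tables:
    print(table.name, table.caption)

# equivalent fire-based CLI call; the wrapper reads the file and writes CSVs + metadata.json:
#   python extract_tables.py paper.html out_tables/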

label_tables.py

Lines changed: 68 additions & 21 deletions
@@ -9,7 +9,9 @@
 import sys
 from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP, InvalidOperation
 from collections import Counter, namedtuple
-
+from joblib import delayed, Parallel
+from sota_extractor2.data.paper_collection import PaperCollection, remove_arxiv_version
+from functools import reduce
 
 arxiv_url_re = re.compile(r"^https?://(www.)?arxiv.org/(abs|pdf|e-print)/(?P<arxiv_id>\d{4}\.[^./]*)(\.pdf)?$")
 
@@ -33,6 +35,8 @@ def get_table(filename):
         return pd.DataFrame()
 
 
+# all_metadata[arxiv_id] = {'table_01.csv': 'Table 1: ...', ...}
+# all_tables[arxiv_id] = {'table_01.csv': DataFrame(...), ...}
 def get_tables(tables_dir):
     tables_dir = Path(tables_dir)
     all_metadata = {}
@@ -223,19 +227,20 @@ def mark_strings(table, tags, values):
             if match_str(real, s):
                 cell_tags += f"{beg}{s}{end}"
     return cell_tags
-
+
 
 metatables = {}
-def match_many(output_dir, task_name, dataset_name, metric_name, tables, values):
+def match_many(task_name, dataset_name, metric_name, tables, values):
+    metatables = {}
     for arxiv_id in tables:
         for table in tables[arxiv_id]:
             tags = mark_with_all_comparators(task_name, dataset_name, metric_name, arxiv_id, tables[arxiv_id][table], values)
-            global metatables
             key = (arxiv_id, table)
             if key in metatables:
                 metatables[key] += tags
             else:
                 metatables[key] = tags
+    return metatables
 
 
 def normalize_metric(value):
@@ -256,6 +261,26 @@ def normalize_table(table):
     return table.applymap(normalize_cell)
 
 
+celltags_re = re.compile(r"<hit><sota>(?P<sota>.*?)</sota><paper>(?P<paper>.*?)</paper><model>(?P<model>.*?)</model><metric>(?P<metric>.*?)</metric><dataset>(?P<dataset>.*?)</dataset><task>(?P<task>.*?)</task>(?P<this_paper><this_paper/>)?<comparator>(?P<comparator>.*?)</comparator><matched_cell>(?P<matched_cell>.*?)</matched_cell><matched_str>(?P<matched_str>.*?)</matched_str></hit>")
+def parse_celltags(v):
+    r = []
+    for m in celltags_re.finditer(v):
+        d = m.groupdict()
+        d['this_paper'] = d['this_paper'] is not None
+        r.append(d)
+    return r
+
+
+def celltags_to_json(df):
+    tags = []
+    for r, row in df.iterrows():
+        for c, cell in enumerate(row):
+            if cell != "":
+                tags.append(dict(row=r, col=c, hits=parse_celltags(cell)))
+    return tags
+
+
+
 # for each task with sota row
 # arxivs <- list of papers related to the task
 # for each (dataset_name, metric_name) of the task:
@@ -269,40 +294,62 @@ def normalize_table(table):
 #             if table.arxiv_id == paper_id: mark with this-tag
 PaperResult = namedtuple("PaperResult", ["arxiv_id", "model", "value", "normalized"])
 
+arxivs_by_metrics = {}
+tables = {}
+
+def match_for(task, dataset, metric):
+    records = arxivs_by_metrics[(task, dataset, metric)]
+    tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
+    return match_many(task, dataset, metric, tabs, records)
+
 
-def label_tables(tasksfile, tables_dir):
-    output_dir = Path(tables_dir)
+def label_tables(tasksfile, papers_dir, output, jobs=-1):
+    print("Reading PwC entries...", file=sys.stderr)
     tasks = get_sota_tasks(tasksfile)
-    metadata, tables = get_tables(tables_dir)
+    print("Reading tables from files...", file=sys.stderr)
+    pc = PaperCollection.from_files(papers_dir, load_texts=False, load_annotations=False, jobs=jobs)
 
-    arxivs_by_metrics = {}
+    # share data between processes to avoid costly joblib serialization
+    global arxivs_by_metrics, tables
 
-    tables = {arxiv_id: {tab: normalize_table(tables[arxiv_id][tab]) for tab in tables[arxiv_id]} for arxiv_id in tables}
+    print("Normalizing tables...", file=sys.stderr)
+    tables = {p.arxiv_no_version: {tab.name: normalize_table(tab.matrix) for tab in p.tables} for p in pc}
 
+    print("Aggregating papers...", file=sys.stderr)
    for task in tasks:
        for dataset in task.datasets:
            for row in dataset.sota.rows:
                match = arxiv_url_re.match(row.paper_url)
                if match is not None:
-                    arxiv_id = match.group("arxiv_id")
+                    arxiv_id = remove_arxiv_version(match.group("arxiv_id"))
                    for metric in row.metrics:
                        arxivs_by_metrics.setdefault((task.name, dataset.name, metric), set()).add(
                            PaperResult(arxiv_id=arxiv_id, model=row.model_name, value=row.metrics[metric],
                                        normalized=normalize_metric(row.metrics[metric])
                            )
                        )
 
-    for task, dataset, metric in arxivs_by_metrics:
-        records = arxivs_by_metrics[(task, dataset, metric)]
-        tabs = {r.arxiv_id: tables[r.arxiv_id] for r in records if r.arxiv_id in tables}
-        match_many(output_dir, task, dataset, metric, tabs, records)
-
-    global metatables
-
-    for (arxiv_id, table), best in metatables.items():
-        out = output_dir / arxiv_id
-        out.mkdir(parents=True, exist_ok=True)
-        best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
+    print("Matching results...", file=sys.stderr)
+    metatables_list = Parallel(n_jobs=jobs, backend="multiprocessing")(
+        [delayed(match_for)(task, dataset, metric)
+         for task, dataset, metric in arxivs_by_metrics])
+
+    print("Aggregating results...", file=sys.stderr)
+    metatables = {}
+    for mt in metatables_list:
+        for k, v in mt.items():
+            metatables[k] = metatables.get(k, "") + v
+    grouped_metatables = {}
+    for (arxiv_id, tablename), df in metatables.items():
+        grouped_metatables.setdefault(arxiv_id, {})[tablename] = celltags_to_json(df)
+
+    with open(output, 'wt') as f:
+        json.dump(grouped_metatables, f)
+    # print("Saving matches...", file=sys.stderr)
+    # for (arxiv_id, table), best in metatables.items():
+    #     out = output_dir / arxiv_id
+    #     out.mkdir(parents=True, exist_ok=True)
+    #     best.to_csv(out / table.replace("table", "celltags"), header=None, index=None)
 
 
 if __name__ == "__main__": fire.Fire(label_tables)
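
label_tables now writes a single JSON file instead of per-paper celltags CSVs: a dict keyed by versionless arXiv id, then by table name, holding the cell hits produced by celltags_to_json. A hedged sketch of the new CLI call and of reading the output (file names are placeholders; the key names follow parse_celltags and celltags_to_json):

# sketch of the new invocation:
#   python label_tables.py evaluation-tables.json papers/ matches.json --jobs 8

import json

with open("matches.json") as f:          # placeholder path
    matches = json.load(f)

for arxiv_id, paper_tables in matches.items():
    for table_name, cells in paper_tables.items():
        for cell in cells:
            for hit in cell["hits"]:     # one dict per celltags_re match
                print(arxiv_id, table_name, cell["row"], cell["col"],
                      hit["task"], hit["dataset"], hit["metric"], hit["this_paper"])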

sota_extractor2/config.py

Lines changed: 9 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 # otherwise use this files
 data = Path("/mnt/efs/pwc/data")
-goldtags_dump = data / "dumps" / "goldtags-2019.08.06_0835.json.gz"
+goldtags_dump = data / "dumps" / "goldtags-2019.10.15_2227.json.gz"
 
 
 elastic = dict(hosts=['localhost'], timeout=20)
@@ -22,3 +22,11 @@
 
 datasets = data/"datasets"
 datasets_structure = datasets/"structure"
+structure_models = datasets / "structure" / "models"
+
+mocks = datasets / "mocks"
+
+linking_models = datasets / "linking" / "models"
+linking_data = datasets / "linking" / "data"
+
+autodict = linking_data / "autodict"

sota_extractor2/data/elastic.py

Lines changed: 20 additions & 4 deletions
@@ -1,3 +1,4 @@
+from bs4 import BeautifulSoup
 import pandas as pd
 import re
 
@@ -162,9 +163,10 @@ def from_json(cls, json, paper_id=None):
         return paper
 
     @classmethod
-    def from_file(cls, path):
+    def from_file(cls, path, paper_id=None):
         path = Path(path)
-        paper_id = path.parent.name
+        if paper_id is None:
+            paper_id = path.parent.name
         with open(path, "rt") as f:
             json = f.read()
         return cls.from_json(json, paper_id)
@@ -187,6 +189,12 @@ def save(self, **kwargs):
         else:
             return super().save(**kwargs)
 
+    def delete(self, **kwargs):
+        if hasattr(self, 'fragments'):
+            for f in self.fragments:
+                f.delete()
+        return super().delete(**kwargs)
+
     @classmethod
     def parse_html(cls, soup, paper_id):
         put_dummy_anchors(soup)
@@ -254,9 +262,17 @@ def read_html(cls, file):
         return read_html(file)
 
     @classmethod
-    def parse_paper(cls, file):
+    def from_html(cls, html, paper_id):
+        soup = BeautifulSoup(html, "html.parser")
+        return cls.parse_html(soup, paper_id)
+
+    @classmethod
+    def parse_paper(cls, file, paper_id=None):
+        file = Path(file)
         soup = cls.read_html(file)
-        return cls.parse_html(soup, file.stem)
+        if paper_id is None:
+            paper_id = file.stem
+        return cls.parse_html(soup, paper_id)
 
 
 class Author(InnerDoc):
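
The new from_html classmethod parses an in-memory HTML string under an explicit paper_id, while parse_paper keeps reading from disk but now accepts an optional paper_id override instead of always using the file stem. A small usage sketch, assuming these methods live on the PaperText document class that paper_collection.py calls (paths and ids below are placeholders):

# parse a stored HTML render; paper_id defaults to the file stem
text = PaperText.parse_paper("renders/1234.56789.html")

# parse an already-loaded HTML string under an explicit id
with open("renders/1234.56789.html", "rb") as f:
    html = f.read()
text = PaperText.from_html(html, "1234.56789")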

sota_extractor2/data/paper_collection.py

Lines changed: 43 additions & 5 deletions
@@ -7,6 +7,10 @@
 from joblib import Parallel, delayed
 from collections import UserList
 from ..helpers.jupyter import display_table
+import string
+import random
+from extract_tables import extract_tables
+
 
 class Paper:
     def __init__(self, paper_id, text, tables, annotations):
@@ -24,6 +28,33 @@ def __init__(self, paper_id, text, tables, annotations):
         else:
             self.gold_tags = ''
 
+    def table_by_name(self, name):
+        for table in self.tables:
+            if table.name == name:
+                return table
+        return None
+
+
+# todo: make sure multithreading/processing won't cause collisions
+def random_id():
+    return "temp_" + ''.join(random.choice(string.ascii_lowercase) for i in range(10))
+
+
+class TempPaper(Paper):
+    """Similar to Paper, but can be used as context manager, temporarily saving the paper to elastic"""
+    def __init__(self, html):
+        paper_id = random_id()
+        text = PaperText.from_html(html, paper_id)
+        tables = extract_tables(html)
+        super().__init__(paper_id=paper_id, text=text, tables=tables, annotations=None)
+
+    def __enter__(self):
+        self.text.save()
+        return self
+
+    def __exit__(self, exc, value, tb):
+        self.text.delete()
+
 
 arxiv_version_re = re.compile(r"v\d+$")
 def remove_arxiv_version(arxiv_id):
@@ -42,8 +73,12 @@ def _load_tables(path, annotations, jobs, migrate):
     return {f.parent.name: tbls for f, tbls in zip(files, tables)}
 
 
-def _load_annotated_papers(path):
-    dump = load_gql_dump(path, compressed=path.suffix == ".gz")["allPapers"]
+def _load_annotated_papers(data_or_path):
+    if isinstance(data_or_path, dict):
+        compressed = False
+    else:
+        compressed = data_or_path.suffix == ".gz"
+    dump = load_gql_dump(data_or_path, compressed=compressed)["allPapers"]
     annotations = {remove_arxiv_version(a.arxiv_id): a for a in dump}
     annotations.update({a.arxiv_id: a for a in dump})
     return annotations
@@ -54,21 +89,24 @@ def __init__(self, data=None):
         super().__init__(data)
 
     @classmethod
-    def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=True, jobs=-1, migrate=False):
+    def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=True, load_annotations=True, jobs=-1, migrate=False):
         path = Path(path)
         if annotations_path is None:
             annotations_path = path / "structure-annotations.json"
+        else:
+            annotations_path = Path(annotations_path)
         if load_texts:
             texts = _load_texts(path, jobs)
         else:
             texts = {}
 
-        annotations = _load_annotated_papers(annotations_path)
+        annotations = {}
         if load_tables:
+            if load_annotations:
+                annotations = _load_annotated_papers(annotations_path)
             tables = _load_tables(path, annotations, jobs, migrate)
         else:
             tables = {}
-            annotations = {}
         outer_join = set(texts).union(set(tables))
 
         papers = [Paper(k, texts.get(k), tables.get(k, []), annotations.get(k)) for k in outer_join]
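
TempPaper ties the pieces together: it parses the HTML into a PaperText, extracts tables with extract_tables, and keeps the text indexed in Elasticsearch only for the lifetime of a with-block (the new delete() override also removes the text's fragments). A hedged usage sketch, assuming an Elasticsearch instance configured as in sota_extractor2.config (the input path and table name are placeholders):

with open("paper.html", "rb") as f:      # placeholder path
    html = f.read()

with TempPaper(html) as paper:
    # the text is saved to elastic under a random temp_* id on __enter__
    table = paper.table_by_name("table_01")   # names follow extract_tables' table_{num:02} scheme
    print(table.caption if table is not None else "no such table")
# __exit__ deletes the temporary text again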
