Skip to content

Commit 917f1b3

Browse files
committed
Test multiprocessing
1 parent b40eefe commit 917f1b3

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ dependencies:
1515
- Unidecode=1.0.23
1616
- elasticsearch-dsl=7.0.0
1717
- ipython=7.5.0
18+
- tqdm=4.28.1
19+
- joblib=0.13.2

sota_extractor2/data/paper_collection.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from .json import load_gql_dump
44
from pathlib import Path
55
import re
6+
from tqdm import tqdm
7+
from joblib import Parallel, delayed
68

79
class Paper:
810
def __init__(self, text, tables, annotations):
@@ -15,11 +17,11 @@ def __init__(self, text, tables, annotations):
1517

1618

1719
arxiv_version_re = re.compile(r"v\d+$")
18-
def clean_arxiv_version(arxiv_id):
20+
def clear_arxiv_version(arxiv_id):
1921
return arxiv_version_re.sub("", arxiv_id)
2022

2123

22-
class PaperCollection:
24+
class PaperCollection(dict):
2325
def __init__(self, path, load_texts=True, load_tables=True):
2426
self.path = path
2527
self.load_texts = load_texts
@@ -50,27 +52,20 @@ def __iter__(self):
5052
return iter(self._papers)
5153

5254
def _load_texts(self):
53-
texts = {}
54-
55-
for f in (self.path / "texts").glob("**/*.json"):
56-
text = PaperText.from_file(f)
57-
texts[clean_arxiv_version(text.meta.id)] = text
58-
return texts
55+
files = list((self.path / "texts").glob("**/*.json"))
56+
texts = Parallel(n_jobs=-1, prefer="processes")(delayed(PaperText.from_file)(f) for f in files)
57+
return {clear_arxiv_version(text.meta.id): text for text in texts}
5958

6059

6160
def _load_tables(self, annotations):
62-
tables = {}
63-
64-
for f in (self.path / "tables").glob("**/metadata.json"):
65-
paper_dir = f.parent
66-
tbls = read_tables(paper_dir, annotations)
67-
tables[clean_arxiv_version(paper_dir.name)] = tbls
68-
return tables
61+
files = list((self.path / "tables").glob("**/metadata.json"))
62+
tables = Parallel(n_jobs=-1, prefer="processes")(delayed(read_tables)(f.parent, annotations) for f in files)
63+
return {clear_arxiv_version(f.parent.name): tbls for f, tbls in zip(files, tables)}
6964

7065
def _load_annotated_papers(self):
7166
dump = load_gql_dump(self.path / "structure-annotations.json.gz", compressed=True)["allPapers"]
7267
annotations = {}
7368
for a in dump:
74-
arxiv_id = clean_arxiv_version(a.arxiv_id)
69+
arxiv_id = clear_arxiv_version(a.arxiv_id)
7570
annotations[arxiv_id] = a
7671
return annotations

sota_extractor2/data/table.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,34 @@
11
import pandas as pd
22
import json
33
from pathlib import Path
4-
from dataclasses import dataclass
4+
import re
5+
from dataclasses import dataclass, field
56
from typing import List
67
from ..helpers.jupyter import display_table
78

89
@dataclass
910
class Cell:
1011
value: str
1112
gold_tags: str = ''
12-
refs: List[str] = None
13+
refs: List[str] = field(default_factory=list)
1314

1415

16+
reference_re = re.compile(r"\[(xxref-[^] ]*)\]")
17+
def extract_references(s):
18+
parts = reference_re.split(s)
19+
return ''.join(parts[::2]), parts[1::2]
20+
21+
22+
def str2cell(s):
23+
value, refs = extract_references(s)
24+
return Cell(value=value, refs=refs)
25+
1526
class Table:
1627
def __init__(self, df, caption=None, figure_id=None, annotations=None):
1728
self.df = df
1829
self.caption = caption
1930
self.figure_id = figure_id
20-
self.df = df.applymap(lambda x: Cell(value=x))
31+
self.df = df.applymap(str2cell)
2132
if annotations is not None:
2233
self.gold_tags = annotations.gold_tags.strip()
2334
tags = annotations.matrix_gold_tags

0 commit comments

Comments (0)