Skip to content

Commit 51d04ce

Browse files
author
Marcin Kardas
committed
Use new annotations format
1 parent 74a1ae6 commit 51d04ce

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

axcell/data/json.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,62 @@ def cut(s, length=20):
7171
vals = pprint.pformat({to_snake_case(k): cut(str(self[k])) for k in self.keys()})
7272
return f"NodeWrap({vals})"
7373

74+
75+
def _annotations_to_gql(annotations):
76+
nodes = []
77+
for a in annotations:
78+
tables = []
79+
for t in a['tables']:
80+
tags = []
81+
if t['leaderboard']:
82+
tags.append('leaderboard')
83+
if t['ablation']:
84+
tags.append('ablation')
85+
if not tags:
86+
tags = ['irrelevant']
87+
88+
records = {}
89+
for r in t['records']:
90+
d = dict(r)
91+
del d['row']
92+
del d['column']
93+
records[f'{r["row"]}.{r["column"]}'] = d
94+
table = {
95+
'node': {
96+
'name': f'table_{t["index"] + 1:02}.csv',
97+
'datasetText': t['dataset_text'],
98+
'notes': '',
99+
'goldTags': ' '.join(tags),
100+
'matrixGoldTags': t['segmentation'],
101+
'cellsSotaRecords': json.dumps(records),
102+
'parser': 'latexml'
103+
}
104+
}
105+
tables.append(table)
106+
node = {
107+
'arxivId': a['arxiv_id'],
108+
'goldTags': a['fold'],
109+
'tableSet': {'edges': tables}
110+
}
111+
nodes.append({'node': node})
112+
return {
113+
'data': {
114+
'allPapers': {
115+
'edges': nodes
116+
}
117+
}
118+
}
119+
120+
74121
def load_gql_dump(data_or_file, compressed=True):
    """Load a papers dump and wrap each top-level entry of its ``"data"``.

    Args:
        data_or_file: already-parsed data (a dict in GraphQL-dump layout or a
            list in the new annotations format), or a path to a JSON file.
        compressed: when a path is given, read it with ``gzip.open`` if true
            and with the builtin ``open`` otherwise. Ignored for parsed data.

    Returns:
        A dict mapping every key of the dump's ``"data"`` object to
        ``wrap_dict(value)``.
    """
    if isinstance(data_or_file, (dict, list)):
        papers_data = data_or_file
    else:
        open_fn = gzip.open if compressed else open
        # JSON is read as UTF-8 explicitly instead of relying on the
        # platform's default text encoding.
        with open_fn(data_or_file, "rt", encoding="utf-8") as f:
            papers_data = json.load(f)
    # Anything that is not a GraphQL dump (a list of annotations, or a
    # mapping without the "data" envelope) is converted to that layout first.
    if isinstance(papers_data, list) or "data" not in papers_data:
        papers_data = _annotations_to_gql(papers_data)
    data = papers_data["data"]
    return {key: wrap_dict(value) for key, value in data.items()}
83132

axcell/data/paper_collection.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,32 @@ def _load_tables(path, annotations, jobs, migrate):
7575
return {f.parent.name: tbls for f, tbls in zip(files, tables)}
7676

7777

78+
def _gql_dump_to_annotations(dump):
79+
annotations = {remove_arxiv_version(a.arxiv_id): a for a in dump}
80+
annotations.update({a.arxiv_id: a for a in dump})
81+
return annotations
82+
7883
def _load_annotated_papers(data_or_path):
    """Load annotations and index them by arXiv id.

    Args:
        data_or_path: already-parsed annotations (dict or list), or a path to
            a JSON dump given as ``str`` or ``Path``. A path whose suffix is
            ``.gz`` is read as gzip-compressed.

    Returns:
        The mapping produced by ``_gql_dump_to_annotations`` for the
        ``"allPapers"`` entry of the loaded dump.
    """
    if isinstance(data_or_path, (dict, list)):
        source = data_or_path
        compressed = False
    else:
        # Accept plain string paths too; `.suffix` exists only on Path objects.
        source = Path(data_or_path)
        compressed = source.suffix == ".gz"
    dump = load_gql_dump(source, compressed=compressed)["allPapers"]
    return _gql_dump_to_annotations(dump)
8790

8891

8992
class PaperCollection(UserList):
9093
def __init__(self, data=None):
9194
super().__init__(data)
9295

9396
@classmethod
def from_files(cls, path, annotations=None, load_texts=True, load_tables=True, jobs=-1):
    """Build a collection from `path`, with optional caller-supplied annotations.

    Delegates to ``_from_files`` with ``annotations_path=None`` and
    ``load_annotations=False``; every other argument is forwarded unchanged.
    """
    return cls._from_files(
        path,
        annotations=annotations,
        annotations_path=None,
        load_texts=load_texts,
        load_tables=load_tables,
        load_annotations=False,
        jobs=jobs,
    )
101+
102+
@classmethod
103+
def _from_files(cls, path, annotations=None, annotations_path=None, load_texts=True, load_tables=True, load_annotations=True, jobs=-1, migrate=False):
95104
path = Path(path)
96105
if annotations_path is None:
97106
annotations_path = path / "structure-annotations.json"
@@ -102,7 +111,10 @@ def from_files(cls, path, annotations_path=None, load_texts=True, load_tables=Tr
102111
else:
103112
texts = {}
104113

105-
annotations = {}
114+
if annotations is None:
115+
annotations = {}
116+
else:
117+
annotations = _load_annotated_papers(annotations)
106118
if load_tables:
107119
if load_annotations:
108120
annotations = _load_annotated_papers(annotations_path)

extract_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def remove_footnotes(soup):
348348

349349

350350
def extract_tables(html):
351-
soup = BeautifulSoup(html, "lxml", from_encoding="utf-8")
351+
soup = BeautifulSoup(html, "lxml")
352352
set_ids_by_labels(soup)
353353
fix_span_tables(soup)
354354
fix_th(soup)

0 commit comments

Comments
 (0)