
Commit 6b85869

Merge commit with 2 parents: 1f999a7 + 6ad4c02

File tree: 11 files changed, +892 / -38 lines

Makefile

Lines changed: 5 additions & 1 deletion

@@ -72,8 +72,12 @@ $(ANNOTATIONS_DIR)/evaluation-tables.json.gz:
 	$(shell mkdir -p "$(ANNOTATIONS_DIR)")
 	wget https://paperswithcode.com/media/about/evaluation-tables.json.gz -O $@
 
+.PHONY: pull_images
+pull_images:
+	docker pull arxivvanity/engrafo:b3db888fefa118eacf4f13566204b68ce100b3a6
+	docker pull zenika/alpine-chrome:73
 
-.PHONY : clean
+.PHONY: clean
 clean :
 	cd "$(ANNOTATIONS_DIR)" && rm -f *.json *.csv
 	#rm -f *.gz

README.md

Lines changed: 2 additions & 1 deletion

@@ -20,8 +20,9 @@ Directory structure:
 ```
 
 
-To preprocess data and extract tables, run:
+To preprocess data and extract tables and texts, run:
 ```
+make pull_images
 conda env create -f environment.yml
 source activate xtables
 make -j 8 -i extract_all > stdout.log 2> stderr.log

helpers.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from fire import Fire
+from pathlib import Path
+from sota_extractor2.data.paper_collection import PaperCollection
+from sota_extractor2.data.structure import CellEvidenceExtractor
+from elasticsearch_dsl import connections
+from tqdm import tqdm
+import pandas as pd
+from joblib import delayed, Parallel
+
+class Helper:
+    def split_pc_pickle(self, path, outdir="pc-parts", parts=8):
+        outdir = Path(outdir)
+        outdir.mkdir(parents=True, exist_ok=True)
+        pc = PaperCollection.from_pickle(path)
+        step = (len(pc) + parts - 1) // parts
+        for idx, i in enumerate(range(0, len(pc), step)):
+            part = PaperCollection(pc[i:i + step])
+            part.to_pickle(outdir / f"pc-part-{idx:02}.pkl")
+
+    def _evidences_for_pc(self, path):
+        path = Path(path)
+        pc = PaperCollection.from_pickle(path)
+        cell_evidences = CellEvidenceExtractor()
+        connections.create_connection(hosts=['10.0.1.145'], timeout=20)
+        raw_evidences = []
+        for paper in tqdm(pc):
+            raw_evidences.append(cell_evidences(paper, paper.tables, paper_limit=100, corpus_limit=20))
+        raw_evidences = pd.concat(raw_evidences)
+        path = path.with_suffix(".evidences.pkl")
+        raw_evidences.to_pickle(path)
+
+    def evidences_for_pc(self, pattern="pc-parts/pc-part-??.pkl", jobs=-1):
+        pickles = sorted(Path(".").glob(pattern))
+        Parallel(backend="multiprocessing", n_jobs=jobs)(delayed(self._evidences_for_pc)(path) for path in pickles)
+
+    def merge_evidences(self, output="evidences.pkl", pattern="pc-parts/pc-part-*.evidences.pkl"):
+        pickles = sorted(Path(".").glob(pattern))
+        evidences = [pd.read_pickle(pickle) for pickle in pickles]
+        evidences = pd.concat(evidences)
+        evidences.to_pickle(output)
+
+
+if __name__ == "__main__": Fire(Helper())
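The new file exposes its methods through a Fire CLI (e.g. `python helpers.py split_pc_pickle pc.pkl`). A minimal sketch of the intended split / extract / merge workflow, written as direct Python calls; the `pc.pkl` path is a hypothetical input, and the Elasticsearch host hard-coded in `_evidences_for_pc` is assumed to be reachable in the deployment:

```python
from helpers import Helper

helper = Helper()
# 1. Split a PaperCollection pickle into 8 part pickles under pc-parts/
helper.split_pc_pickle("pc.pkl", outdir="pc-parts", parts=8)
# 2. Extract cell evidences for each part in parallel (one process per part pickle)
helper.evidences_for_pc(pattern="pc-parts/pc-part-??.pkl", jobs=-1)
# 3. Merge the per-part evidence pickles into a single evidences.pkl
helper.merge_evidences(output="evidences.pkl", pattern="pc-parts/pc-part-*.evidences.pkl")
```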

sota_extractor2/data/table.py

Lines changed: 4 additions & 0 deletions

@@ -146,6 +146,10 @@ def set_tags(self, tags):
                 # todo: change gold_tags to tags to avoid confusion
                 self.df.iloc[r,c].gold_tags = cell.strip()
 
+    @property
+    def shape(self):
+        return self.df.shape
+
     @property
     def matrix(self):
         return self.df.applymap(lambda x: x.value)

sota_extractor2/helpers/training.py

Lines changed: 6 additions & 2 deletions

@@ -1,10 +1,14 @@
 
-def set_seed(seed, name, quiet=False):
+def set_seed(seed, name, quiet=False, all_gpus=True):
     import torch
     import numpy as np
+    import random
     if not quiet:
         print(f"Setting {name} seed to {seed}")
     torch.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
-    np.random.seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    if all_gpus:
+        torch.cuda.manual_seed_all(seed)
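For reference, a minimal usage sketch of the extended helper; the call site and seed value are illustrative, not part of this commit:

```python
from sota_extractor2.helpers.training import set_seed

# Seeds PyTorch (CPU and, with all_gpus=True, every visible GPU via
# torch.cuda.manual_seed_all), NumPy and Python's random module, and
# switches cuDNN to deterministic mode before a training run.
set_seed(42, "structure-model", all_gpus=True)
```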

sota_extractor2/models/linking/context_search.py

Lines changed: 29 additions & 8 deletions

@@ -138,13 +138,17 @@
     'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
 })
 
+tasks = {}
+
 # escaped_ws_re = re.compile(r'\\\s+')
 # def name_to_re(name):
 #     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
 
 #all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
 all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in datasets.values() for y in x)
 all_metrics = set(normalize_cell_ws(y) for x in metrics.values() for y in x)
+all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in tasks.values() for y in x)
+
 #all_metrics = set(metrics_p.keys())
 
 # all_datasets_re = {x:name_to_re(x) for x in all_datasets}
@@ -187,6 +191,7 @@ def find_names(text, names_trie):
 
 all_datasets_trie = make_trie(all_datasets)
 all_metrics_trie = make_trie(all_metrics)
+all_tasks_trie = make_trie(all_tasks)
 
 
 def find_datasets(text):
@@ -195,18 +200,23 @@ def find_datasets(text):
 def find_metrics(text):
     return find_names(text, all_metrics_trie)
 
+def find_tasks(text):
+    return find_names(text, all_tasks_trie)
+
 def dummy_item(reason):
     return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))
 
 
 
 @njit
-def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, ms_noise, ds_pb, ms_pb, logprobs):
+def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
+                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb, logprobs):
     empty = typed.Dict.empty(types.unicode_type, types.float64)
     for i, (task, dataset, metric) in enumerate(taxonomy):
         logprob = 0.0
         short_probs = reverse_merged_p.get(dataset, empty)
         met_probs = reverse_metrics_p.get(metric, empty)
+        task_probs = reverse_task_p.get(task, empty)
         for ds in dss:
             # for abbrv, long_form in abbrvs.items():
             #     if ds == abbrv:
@@ -216,17 +226,21 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, no
             logprob += np.log(noise * ds_pb + (1 - noise) * short_probs.get(ds, 0.0))
         for ms in mss:
             logprob += np.log(ms_noise * ms_pb + (1 - ms_noise) * met_probs.get(ms, 0.0))
+        for ts in tss:
+            logprob += np.log(ts_noise * ts_pb + (1 - ts_noise) * task_probs.get(ts, 0.0))
         logprobs[i] += logprob
         #logprobs[(dataset, metric)] = logprob
 
 
 class ContextSearch:
-    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, ds_pb=0.001, ms_pb=0.01, debug_gold_df=None):
+    def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
+                 ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
         merged_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
         metrics_p = \
             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
-
+        tasks_p = \
+            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in tasks.items()})[1]
 
         self.queries = {}
         self.taxonomy = taxonomy
@@ -236,10 +250,13 @@ def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None,
         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
         self.metrics_noise = metrics_noise if metrics_noise else context_noise
+        self.task_noise = task_noise if task_noise else context_noise
         self.ds_pb = ds_pb
         self.ms_pb = ms_pb
+        self.ts_pb = ts_pb
         self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
         self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
+        self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
         self.debug_gold_df = debug_gold_df
 
     def _numba_update_nested_dict(self, nested):
@@ -256,29 +273,33 @@ def _numba_extend_list(self, lst):
             l.append(x)
         return l
 
-    def compute_context_logprobs(self, context, noise, ms_noise, logprobs):
+    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
         context = context or ""
         abbrvs = self.extract_acronyms(context)
         context = normalize_cell_ws(normalize_dataset(context))
         dss = set(find_datasets(context)) | set(abbrvs.keys())
         mss = set(find_metrics(context))
+        tss = set(find_tasks(context))
         dss -= mss
+        dss -= tss
         dss = [normalize_cell(ds) for ds in dss]
         mss = [normalize_cell(ms) for ms in mss]
+        tss = [normalize_cell(ts) for ts in tss]
         ###print("dss", dss)
         ###print("mss", mss)
         dss = self._numba_extend_list(dss)
         mss = self._numba_extend_list(mss)
-        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p,
-                         dss, mss, noise, ms_noise, self.ds_pb, self.ms_pb, logprobs)
+        tss = self._numba_extend_list(tss)
+        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
+                         dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb, logprobs)
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
         n = len(self._taxonomy)
         context_logprobs = np.zeros(n)
 
-        for context, noise, ms_noise in zip(contexts, self.context_noise, self.metrics_noise):
-            self.compute_context_logprobs(context, noise, ms_noise, context_logprobs)
+        for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
+            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs)
         keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
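The new task terms follow the same noise-mixture pattern as the existing dataset and metric terms in compute_logprobs: each mention found in the context contributes log(noise * prior + (1 - noise) * P(mention | candidate)) to the candidate's log-probability. A self-contained sketch of that per-mention scoring; the probability table, noise and prior values below are made up for illustration:

```python
import numpy as np

def mention_logprob(mentions, cond_probs, noise, prior):
    # Sum of log-mixture terms: a uniform "noise" floor (prior) blended with
    # the conditional probability of seeing this mention given the candidate.
    return sum(np.log(noise * prior + (1 - noise) * cond_probs.get(m, 0.0)) for m in mentions)

# Hypothetical P(mention | task) for one taxonomy entry, scored against the
# task mentions found in a table caption.
task_probs = {"speech recognition": 0.7, "asr": 0.2}
score = mention_logprob({"speech recognition"}, task_probs, noise=0.1, prior=0.01)
print(score)  # ~ log(0.1*0.01 + 0.9*0.7)
```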

sota_extractor2/models/structure/experiment.py

Lines changed: 44 additions & 7 deletions

@@ -20,15 +20,40 @@ class Labels(Enum):
     EMPTY=5
 
 
+class LabelsExt(Enum):
+    OTHER=0
+    PARAMS=6
+    TASK=7
+    DATASET=1
+    SUBDATASET=8
+    PAPER_MODEL=2
+    BEST_MODEL=9
+    ENSEMBLE_MODEL=10
+    COMPETING_MODEL=3
+    METRIC=4
+    EMPTY=5
+
+
 label_map = {
     "dataset": Labels.DATASET.value,
     "dataset-sub": Labels.DATASET.value,
     "model-paper": Labels.PAPER_MODEL.value,
     "model-best": Labels.PAPER_MODEL.value,
     "model-ensemble": Labels.PAPER_MODEL.value,
     "model-competing": Labels.COMPETING_MODEL.value,
-    "dataset-metric": Labels.METRIC.value,
-    # "model-params": Labels.PARAMS.value
+    "dataset-metric": Labels.METRIC.value
+}
+
+
+label_map_ext = {
+    "dataset": LabelsExt.DATASET.value,
+    "dataset-sub": LabelsExt.SUBDATASET.value,
+    "model-paper": LabelsExt.PAPER_MODEL.value,
+    "model-best": LabelsExt.BEST_MODEL.value,
+    "model-ensemble": LabelsExt.ENSEMBLE_MODEL.value,
+    "model-competing": LabelsExt.COMPETING_MODEL.value,
+    "dataset-metric": LabelsExt.METRIC.value,
+    "model-params": LabelsExt.PARAMS.value,
+    "dataset-task": LabelsExt.TASK.value
 }
 
 # put here to avoid recompiling, used only in _limit_context
@@ -63,6 +88,7 @@ class Experiment:
     remove_num: bool = True
     drop_duplicates: bool = True
     mark_this_paper: bool = False
+    distinguish_model_source: bool = True
 
     results: dict = dataclasses.field(default_factory=dict)
 
@@ -219,6 +245,8 @@ def _transform_df(self, df):
         df = df.replace(re.compile(r"(^|[ ])\d+(\b|%)"), " xxnum ")
         df = df.replace(re.compile(r"\bdata set\b"), " dataset ")
         df["label"] = df["cell_type"].apply(lambda x: label_map.get(x, 0))
+        if not self.distinguish_model_source:
+            df["label"] = df["label"].apply(lambda x: x if x != Labels.COMPETING_MODEL.value else Labels.PAPER_MODEL.value)
         df["label"] = pd.Categorical(df["label"])
         return df
 
@@ -228,13 +256,15 @@ def transform_df(self, *dfs):
             return transformed[0]
         return transformed
 
-    def _set_results(self, prefix, preds, true_y):
+    def _set_results(self, prefix, preds, true_y, true_y_ext=None):
         m = metrics(preds, true_y)
         r = {}
         r[f"{prefix}_accuracy"] = m["accuracy"]
         r[f"{prefix}_precision"] = m["precision"]
         r[f"{prefix}_recall"] = m["recall"]
         r[f"{prefix}_cm"] = confusion_matrix(true_y, preds, labels=[x.value for x in Labels]).tolist()
+        if true_y_ext is not None:
+            r[f"{prefix}_cm_full"] = confusion_matrix(true_y_ext, preds, labels=[x.value for x in LabelsExt]).tolist()
         self.update_results(**r)
 
     def evaluate(self, model, train_df, valid_df, test_df):
@@ -253,17 +283,19 @@ def evaluate(self, model, train_df, valid_df, test_df):
             true_y = vote_results["true"]
         else:
             true_y = tdf["label"]
-        self._set_results(prefix, preds, true_y)
+        true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
+        self._set_results(prefix, preds, true_y, true_y_ext)
 
-    def show_results(self, *ds, normalize=True):
+    def show_results(self, *ds, normalize=True, full_cm=True):
         if not len(ds):
             ds = ["train", "valid", "test"]
         for prefix in ds:
             print(f"{prefix} dataset")
             print(f" * accuracy: {self.results[f'{prefix}_accuracy']:.3f}")
             print(f" * μ-precision: {self.results[f'{prefix}_precision']:.3f}")
             print(f" * μ-recall: {self.results[f'{prefix}_recall']:.3f}")
-            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm']), normalize=normalize)
+            suffix = '_full' if full_cm and f'{prefix}_cm_full' in self.results else ''
+            self._plot_confusion_matrix(np.array(self.results[f'{prefix}_cm{suffix}']), normalize=normalize)
 
     def _plot_confusion_matrix(self, cm, normalize, fmt=None):
         if normalize:
@@ -272,7 +304,12 @@ def _plot_confusion_matrix(self, cm, normalize, fmt=None):
             cm = cm / s
         if fmt is None:
            fmt = "0.2f" if normalize else "d"
-        target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+
+        if len(cm) == 6:
+            target_names = ["OTHER", "DATASET", "MODEL (paper)", "MODEL (comp.)", "METRIC", "EMPTY"]
+        else:
+            target_names = ["OTHER", "params", "task", "DATASET", "subdataset", "MODEL (paper)", "model (best)",
+                            "model (ens.)", "MODEL (comp.)", "METRIC", "EMPTY"]
         df_cm = pd.DataFrame(cm, index=[i for i in target_names],
                              columns=[i for i in target_names])
         plt.figure(figsize=(10, 10))