
Commit 437a68e

Author: Marcin Kardas

Use gold_sota_records for error analysis

* change taxonomy format to list of triplets
* change hyphens to spaces when normalizing datasets

1 parent a69176e · commit 437a68e

8 files changed: +49 −32 lines

sota_extractor2/helpers/explainers.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -91,10 +91,11 @@ def _repr_html_(self):
 class Explainer:
     _sota_record_columns = ['task', 'dataset', 'metric', 'format', 'model', 'model_type', 'raw_value', 'parsed']
 
-    def __init__(self, pipeline_logger, paper_collection):
+    def __init__(self, pipeline_logger, paper_collection, gold_sota_records=None):
         self.paper_collection = paper_collection
+        self.gold_sota_records = gold_sota_records
         self.spe = StructurePredictionEvaluator(pipeline_logger, paper_collection)
-        self.le = LinkerEvaluator(pipeline_logger, paper_collection)
+        self.le = LinkerEvaluator(pipeline_logger)
         self.fe = FilteringEvaluator(pipeline_logger)
 
     def explain(self, paper, cell_ext_id):
@@ -179,11 +180,20 @@ def linking_metrics(self, experiment_name="unk"):
             print(", ".join(missing))
         papers = [paper for paper in papers.values() if paper is not None]
 
-        if not len(papers):
+        # if not len(papers):
+        #     gold_sota_records = pd.DataFrame(columns=self._sota_record_columns)
+        #     gold_sota_records.index.rename("cell_ext_id", inplace=True)
+        # else:
+        #     gold_sota_records = pd.concat([self._get_sota_records(paper) for paper in papers])
+        if self.gold_sota_records is None:
             gold_sota_records = pd.DataFrame(columns=self._sota_record_columns)
             gold_sota_records.index.rename("cell_ext_id", inplace=True)
         else:
-            gold_sota_records = pd.concat([self._get_sota_records(paper) for paper in papers])
+
+            gold_sota_records = self.gold_sota_records
+            which = gold_sota_records.index.to_series().str.split("/", expand=True)[0]\
+                .isin([paper.paper_id for paper in papers])
+            gold_sota_records = gold_sota_records[which]
 
         df = gold_sota_records.merge(proposals, 'outer', left_index=True, right_index=True, suffixes=['_gold', '_pred'])
         df = df.reindex(sorted(df.columns), axis=1)
```
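The new `gold_sota_records` argument lets the caller hand `Explainer` a precomputed frame of gold records instead of rebuilding them from the papers. A minimal sketch of the filtering step in `linking_metrics` (the records and `cell_ext_id` values below are made up):

```python
import pandas as pd

# Made-up gold records indexed by cell_ext_id ("<paper_id>/<table>/<cell>");
# in the real pipeline this frame is passed via Explainer(gold_sota_records=...).
gold_sota_records = pd.DataFrame(
    {"task": ["Language Modelling", "Image Classification"],
     "dataset": ["WikiText-2", "ImageNet"],
     "metric": ["Test perplexity", "Top 1 Accuracy"]},
    index=pd.Index(["1234.05678/table_01.csv/2.1",
                    "2345.06789/table_02.csv/3.2"], name="cell_ext_id"))

papers_in_collection = ["1234.05678"]  # stand-in for [paper.paper_id for paper in papers]

# The prefix before the first "/" is the paper id; keep only gold records
# for papers that are actually being evaluated.
which = gold_sota_records.index.to_series().str.split("/", expand=True)[0] \
    .isin(papers_in_collection)
print(gold_sota_records[which])
```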

sota_extractor2/loggers.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -144,7 +144,7 @@ def get_table_type_predictions(self, paper_id, table_name):
 
 
 class LinkerEvaluator:
-    def __init__(self, pipeline_logger, pc):
+    def __init__(self, pipeline_logger):
         pipeline_logger.register("linking::call", self.on_before_linking)
         pipeline_logger.register("linking::taxonomy_linking::call", self.on_before_taxonomy)
         pipeline_logger.register("linking::taxonomy_linking::topk", self.on_taxonomy_topk)
```

sota_extractor2/models/linking/acronym_extractor.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,6 +1,6 @@
 import spacy
 from scispacy.abbreviation import AbbreviationDetector
-from .utils import normalize_cell, normalize_dataset
+from .utils import normalize_cell, normalize_dataset_ws
 
 class AcronymExtractor:
     def __init__(self):
@@ -14,7 +14,7 @@ def __call__(self, text):
         abbrvs = {}
         for abrv in doc._.abbreviations:
             # abbrvs.setdefault(normalize_cell(str(abrv)), Counter())[str(abrv._.long_form)] += 1
-            norm = normalize_cell(normalize_dataset(str(abrv)))
+            norm = normalize_cell(normalize_dataset_ws(str(abrv)))
             if norm != '':
-                abbrvs[norm] = normalize_cell(normalize_dataset(str(abrv._.long_form)))
+                abbrvs[norm] = normalize_cell(normalize_dataset_ws(str(abrv._.long_form)))
         return abbrvs
```

sota_extractor2/models/linking/context_search.py

Lines changed: 22 additions & 20 deletions
```diff
@@ -3,7 +3,7 @@
 
 from sota_extractor2.models.linking.acronym_extractor import AcronymExtractor
 from sota_extractor2.models.linking.probs import get_probs, reverse_probs
-from sota_extractor2.models.linking.utils import normalize_dataset, normalize_cell, normalize_cell_ws
+from sota_extractor2.models.linking.utils import normalize_dataset_ws, normalize_cell, normalize_cell_ws
 from scipy.special import softmax
 import re
 import pandas as pd
@@ -201,9 +201,9 @@ def dummy_item(reason):
 
 
 @njit
-def compute_logprobs(dataset_metric, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
+def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
     empty = typed.Dict.empty(types.unicode_type, types.float64)
-    for i, (dataset, metric) in enumerate(dataset_metric):
+    for i, (task, dataset, metric) in enumerate(taxonomy):
         logprob = 0.0
         short_probs = reverse_merged_p.get(dataset, empty)
         met_probs = reverse_metrics_p.get(metric, empty)
@@ -223,16 +223,16 @@ def compute_logprobs(dataset_metric, reverse_merged_p, reverse_metrics_p, dss, mss, noise, logprobs):
 class ContextSearch:
     def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), debug_gold_df=None):
         merged_p = \
-            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
+            get_probs({k: Counter([normalize_cell(normalize_dataset_ws(x)) for x in v]) for k, v in datasets.items()})[1]
         metrics_p = \
-            get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
+            get_probs({k: Counter([normalize_cell(normalize_dataset_ws(x)) for x in v]) for k, v in metrics.items()})[1]
 
 
         self.queries = {}
         self.taxonomy = taxonomy
-        self._dataset_metric = typed.List()
+        self._taxonomy = typed.List()
         for t in self.taxonomy.taxonomy:
-            self._dataset_metric.append(t)
+            self._taxonomy.append(t)
         self.extract_acronyms = AcronymExtractor()
         self.context_noise = context_noise
         self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
@@ -254,8 +254,9 @@ def _numba_extend_list(self, lst):
         return l
 
     def compute_context_logprobs(self, context, noise, logprobs):
+        context = context or ""
         abbrvs = self.extract_acronyms(context)
-        context = normalize_cell_ws(normalize_dataset(context))
+        context = normalize_cell_ws(normalize_dataset_ws(context))
         dss = set(find_datasets(context)) | set(abbrvs.keys())
         mss = set(find_metrics(context))
         dss -= mss
@@ -265,16 +266,16 @@ def compute_context_logprobs(self, context, noise, logprobs):
         ###print("mss", mss)
         dss = self._numba_extend_list(dss)
         mss = self._numba_extend_list(mss)
-        compute_logprobs(self._dataset_metric, self.reverse_merged_p, self.reverse_metrics_p, dss, mss, noise, logprobs)
+        compute_logprobs(self._taxonomy, self.reverse_merged_p, self.reverse_metrics_p, dss, mss, noise, logprobs)
 
     def match(self, contexts):
         assert len(contexts) == len(self.context_noise)
-        n = len(self._dataset_metric)
-        context_logprobs = np.ones(n)
+        n = len(self._taxonomy)
+        context_logprobs = np.zeros(n)
 
         for context, noise in zip(contexts, self.context_noise):
             self.compute_context_logprobs(context, noise, context_logprobs)
-        keys = self.taxonomy.taxonomy.keys()
+        keys = self.taxonomy.taxonomy
         logprobs = context_logprobs
         #keys, logprobs = zip(*context_logprobs.items())
         probs = softmax(np.array(logprobs))
@@ -290,12 +291,12 @@ def __call__(self, query, datasets, caption, debug_info=None):
         ###print("query:", query, caption)
         if key in self.queries:
             # print(self.queries[key])
-            for context in key:
-                abbrvs = self.extract_acronyms(context)
-                context = normalize_cell_ws(normalize_dataset(context))
-                dss = set(find_datasets(context)) | set(abbrvs.keys())
-                mss = set(find_metrics(context))
-                dss -= mss
+            # for context in key:
+            #     abbrvs = self.extract_acronyms(context)
+            #     context = normalize_cell_ws(normalize_dataset_ws(context))
+            #     dss = set(find_datasets(context)) | set(abbrvs.keys())
+            #     mss = set(find_metrics(context))
+            #     dss -= mss
             ###print("dss", dss)
             ###print("mss", mss)
 
@@ -307,7 +308,8 @@ def __call__(self, query, datasets, caption, debug_info=None):
 
         entries = []
         for it, prob in topk:
-            entry = dict(self.taxonomy.taxonomy[it])
+            task, dataset, metric = it
+            entry = dict(task=task, dataset=dataset, metric=metric)
             entry.update({"evidence": "", "confidence": prob})
             entries.append(entry)
 
@@ -351,4 +353,4 @@ def from_paper(self, paper):
         return self(text)
 
     def __call__(self, text):
-        return find_datasets(normalize_cell_ws(normalize_dataset(text)))
+        return find_datasets(normalize_cell_ws(normalize_dataset_ws(text)))
```
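Two related changes here: the taxonomy is now iterated as `(task, dataset, metric)` triplets, and `match` initializes its accumulator with `np.zeros` rather than `np.ones`, since log-probabilities are summed across contexts and the neutral starting value for a sum is log 1 = 0. A pure-Python sketch of the accumulation (the real `compute_logprobs` is numba-compiled; the noise-smoothing inside the loop and all probabilities below are assumptions, not taken from the diff):

```python
import numpy as np
from scipy.special import softmax

# Toy taxonomy in the new triplet format.
taxonomy = [
    ("Language Modelling", "WikiText-2", "Test perplexity"),
    ("Image Classification", "ImageNet", "Top 1 Accuracy"),
]
# Made-up evidence probabilities keyed like reverse_merged_p / reverse_metrics_p.
reverse_merged_p = {"WikiText-2": {"wikitext 2": 0.9}}
reverse_metrics_p = {"Test perplexity": {"perplexity": 0.8}}

def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p,
                     dss, mss, noise, logprobs):
    for i, (task, dataset, metric) in enumerate(taxonomy):
        logprob = 0.0
        short_probs = reverse_merged_p.get(dataset, {})
        met_probs = reverse_metrics_p.get(metric, {})
        # Assumed smoothing: mix each evidence probability with uniform noise.
        for ds in dss:
            logprob += np.log(noise + (1 - noise) * short_probs.get(ds, 0.0))
        for ms in mss:
            logprob += np.log(noise + (1 - noise) * met_probs.get(ms, 0.0))
        logprobs[i] += logprob

# np.zeros, not np.ones: scores accumulate additively over contexts,
# so starting from 1.0 would bias every entry by a constant offset.
logprobs = np.zeros(len(taxonomy))
compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p,
                 ["wikitext 2"], ["perplexity"], 0.5, logprobs)
print(softmax(logprobs))  # highest probability for the WikiText-2 entry
```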

sota_extractor2/models/linking/taxonomy.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from collections import OrderedDict
 
 
 
@@ -14,7 +15,7 @@ def _read_json(self, path):
 
     def _read_taxonomy(self, path):
         records = self._read_json(path)
-        return {(x['dataset'], x['metric']): x for x in records}
+        return [(r["task"], r["dataset"], r["metric"]) for r in records]
 
     def _read_metrics_info(self, path):
         records = self._read_json(path)
```
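`_read_taxonomy` now returns a plain list of `(task, dataset, metric)` triplets instead of a dict keyed by `(dataset, metric)`; the list form also preserves entries that share a dataset/metric pair but differ in task. A small before/after sketch (the JSON contents are made up):

```python
import json

# Taxonomy records as they might appear in the JSON file.
records = json.loads('''[
  {"task": "Language Modelling", "dataset": "WikiText-2", "metric": "Test perplexity"},
  {"task": "Image Classification", "dataset": "ImageNet", "metric": "Top 1 Accuracy"}
]''')

# Old format: dict keyed by (dataset, metric); records sharing a key collide.
old = {(x['dataset'], x['metric']): x for x in records}

# New format: list of (task, dataset, metric) triplets, directly iterable by
# ContextSearch and convertible to a numba typed.List.
new = [(r["task"], r["dataset"], r["metric"]) for r in records]
print(new[0])  # ('Language Modelling', 'WikiText-2', 'Test perplexity')
```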

sota_extractor2/models/linking/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ def clean_cell(cell):
 def remove_references(s):
     return refs_re.sub("", s)
 
-def normalize_dataset2(name):
+def normalize_dataset_ws(name):
     name = remove_references(name)
     name = hyphens_re.sub(" ", name)
     name = year_2k_re.sub(r"\1", name)
```
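`normalize_dataset2` is renamed to `normalize_dataset_ws`, matching the behavior announced in the commit message: hyphens become spaces rather than being dropped. A sketch with plausible stand-ins for the patterns (the real `refs_re`, `hyphens_re`, and `year_2k_re` are defined elsewhere in `utils.py` and may differ, and the final `strip().lower()` is an assumption):

```python
import re

refs_re = re.compile(r"\[[^\]]*\]")           # citation markers like "[13]"
hyphens_re = re.compile(r"[-\u2010-\u2015]")  # ASCII hyphen and Unicode dashes
year_2k_re = re.compile(r"\b20(\d\d)\b")      # "2012" -> "12"

def normalize_dataset_ws(name):
    # Hyphens become spaces (the old normalize_dataset removed them), so
    # "WikiText-2" now yields "wikitext 2" rather than "wikitext2".
    name = refs_re.sub("", name)
    name = hyphens_re.sub(" ", name)
    name = year_2k_re.sub(r"\1", name)
    return name.strip().lower()  # assumed final cleanup

print(normalize_dataset_ws("WikiText-2 [13]"))  # -> "wikitext 2"
```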

sota_extractor2/models/structure/structure_predictor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -122,6 +122,8 @@ def to_tables(self, df, transpose=False):
         return X_tables, C_tables, ids
 
     def merge_with_preds(self, df, preds):
+        if not len(df):
+            return []
         ext_id = df.ext_id.str.split("/", expand=True)
         return list(zip(ext_id[0] + "/" + ext_id[1], ext_id[2].astype(int), ext_id[3].astype(int),
                         preds, df.text, df.cell_content, df.cell_layout, df.cell_styles, df.cell_reference, df.label))
```
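The early return guards against papers with no extracted cells: on an empty frame, `str.split(..., expand=True)` yields a frame with no columns (at least on the pandas versions this code targets), so indexing `ext_id[0]` would raise a `KeyError`. A quick reproduction:

```python
import pandas as pd

# A paper with no extracted cells produces an empty frame.
df = pd.DataFrame({"ext_id": pd.Series([], dtype="object")})
ext_id = df.ext_id.str.split("/", expand=True)
print(ext_id.columns.tolist())  # [] -- ext_id[0] would raise KeyError

# merge_with_preds now short-circuits instead:
if not len(df):
    print("returning [] early")
```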

sota_extractor2/models/structure/type_predictor.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -30,8 +30,10 @@ def predict(self, paper, tables):
         if len(tables) == 0:
             predictions = []
         else:
-            df = pd.DataFrame({"caption": [table.caption if table.caption else "" for table in tables]})
-            tl = TextList.from_df(df, cols="caption")
+            column = "caption"
+            df = pd.DataFrame({column: [table.caption if table.caption else "Table" for table in tables]})
+            inputs = df.iloc[:, df_names_to_idx(column, df)]
+            tl = TextList(items=inputs.values[:, 0], path='.', inner_df=df, processor=None)
             self.learner.data.add_test(tl)
             preds, _ = self.learner.get_preds(DatasetType.Test, ordered=True)
             pipeline_logger(f"{TableTypePredictor.step}::multiclass_predicted", paper=paper, tables=tables,
```
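The table-type predictor now builds the `TextList` directly instead of via `TextList.from_df`, and empty captions fall back to the literal string `"Table"` rather than `""`, so no test item is empty. A sketch of the same construction, assuming fastai v1 (where `df_names_to_idx` lives in `fastai.core`); the motivation, presumably skipping `from_df`'s default processing of the test items via `processor=None`, isn't stated in the commit:

```python
import pandas as pd
from fastai.core import df_names_to_idx
from fastai.text import TextList

captions = ["Table 1: test perplexity on WikiText-2", "", None]
column = "caption"
# Empty or missing captions become the placeholder "Table".
df = pd.DataFrame({column: [c if c else "Table" for c in captions]})
inputs = df.iloc[:, df_names_to_idx(column, df)]
# Direct construction with processor=None bypasses TextList.from_df.
tl = TextList(items=inputs.values[:, 0], path='.', inner_df=df, processor=None)
```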
