Skip to content

Commit eda6a36

Browse files
author
Marcin Kardas
committed
Use abbreviations dicts by default
1 parent ed6aa99 commit eda6a36

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

axcell/models/linking/context_search.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010
import re
1111
import pandas as pd
1212
import numpy as np
13+
import json
1314
import ahocorasick
1415
from numba import njit, typed, types
16+
from pathlib import Path
1517

1618
from axcell.pipeline_logger import pipeline_logger
1719

1820
from axcell.models.linking import manual_dicts
1921
from collections import Counter
2022

23+
2124
def dummy_item(reason):
    """Return a one-row placeholder result frame.

    Every linking column (dataset / task / metric) is filled with the given
    *reason* string, with empty evidence and zero confidence, so failed
    lookups flow through the same DataFrame-shaped pipeline as real hits.
    """
    row = {
        "dataset": [reason],
        "task": [reason],
        "metric": [reason],
        "evidence": [""],
        "confidence": [0.0],
    }
    return pd.DataFrame(row)
2326

@@ -28,7 +31,9 @@ class EvidenceFinder:
2831
end_letter_re = re.compile(r"\w\b")
2932
letter_re = re.compile(r"\w")
3033

31-
def __init__(self, taxonomy):
34+
def __init__(self, taxonomy, abbreviations_path=None, use_manual_dicts=False):
    """Build evidence-lookup structures for the given taxonomy.

    :param taxonomy: taxonomy object exposing the datasets/metrics/tasks
        used to build the evidence dictionaries (via ``_init_structs``).
    :param abbreviations_path: optional path to a JSON file mapping
        abbreviations to lists of evidence strings; when given, automatic
        evidences are derived from it in ``init_evidence_dicts``.
    :param use_manual_dicts: when True, also merge the hand-curated
        ``manual_dicts`` evidences into the dictionaries.
    """
    self.abbreviations_path = abbreviations_path
    self.use_manual_dicts = use_manual_dicts
    self._init_structs(taxonomy)
3338

3439
@staticmethod
@@ -58,6 +63,14 @@ def make_trie(names):
5863
trie.make_automaton()
5964
return trie
6065

66+
@staticmethod
67+
def get_auto_evidences(name, abbreviations, abbrvs_trie):
68+
frags = EvidenceFinder.find_names(normalize_dataset_ws(name), abbrvs_trie)
69+
evidences = []
70+
for f in frags:
71+
evidences.extend(abbreviations[f])
72+
return list(set(evidences))
73+
6174
@staticmethod
6275
def find_names(text, names_trie):
6376
text = text.lower()
@@ -84,15 +97,30 @@ def find_tasks(self, text):
8497

8598
def init_evidence_dicts(self, taxonomy):
    """Populate ``self.tasks`` / ``self.datasets`` / ``self.metrics``.

    Starts from the basic name-derived dicts, then optionally layers in
    the hand-curated manual dictionaries and/or automatic evidences mined
    from an abbreviations JSON file, and finally adds split-specific
    evidence words ('test' vs 'validation'/'dev'/'development').

    :param taxonomy: taxonomy exposing ``datasets`` and ``metrics``
        iterables used to derive automatic evidences.
    """
    self.tasks, self.datasets, self.metrics = EvidenceFinder.get_basic_dicts(taxonomy)

    # Manual, hand-curated evidences are now opt-in (default is the
    # abbreviations-based automatic dictionaries).
    if self.use_manual_dicts:
        EvidenceFinder.merge_evidences(self.tasks, manual_dicts.tasks)
        EvidenceFinder.merge_evidences(self.datasets, manual_dicts.datasets)
        EvidenceFinder.merge_evidences(self.metrics, manual_dicts.metrics)

    if self.abbreviations_path is not None:
        # JSON file: abbreviation -> list of evidence strings.
        with Path(self.abbreviations_path).open('rt') as f:
            abbreviations = json.load(f)
        abbrvs_trie = EvidenceFinder.make_trie(list(abbreviations.keys()))

        # Mine automatic evidences for every taxonomy dataset/metric name.
        ds_auto = {x: EvidenceFinder.get_auto_evidences(x, abbreviations, abbrvs_trie) for x in taxonomy.datasets}
        ms_auto = {x: EvidenceFinder.get_auto_evidences(x, abbreviations, abbrvs_trie) for x in taxonomy.metrics}

        EvidenceFinder.merge_evidences(self.datasets, ds_auto)
        EvidenceFinder.merge_evidences(self.metrics, ms_auto)

    # Dataset names containing 'val' get validation-split evidences,
    # everything else gets 'test'.
    self.datasets = {k: (v + ['test'] if 'val' not in k else v + ['validation', 'dev', 'development']) for k, v in
                     self.datasets.items()}
    if self.use_manual_dicts:
        self.datasets.update({
            'LibriSpeech dev-clean': ['libri speech dev clean', 'libri speech', 'dev', 'clean', 'dev clean', 'development'],
            'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
        })
96124

97125
def _init_structs(self, taxonomy):
98126
self.init_evidence_dicts(taxonomy)
@@ -163,7 +191,10 @@ def _to_typed_list(iterable):
163191

164192

165193
class ContextSearch:
166-
def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.1, 0.2, 0.2, 0.1), metric_noise=None, task_noise=None,
194+
def __init__(self, taxonomy, evidence_finder,
195+
context_noise=(0.99, 1.0, 1.0, 0.25, 0.01),
196+
metric_noise=(0.99, 1.0, 1.0, 0.25, 0.01),
197+
task_noise=(0.1, 1.0, 1.0, 0.1, 0.1),
167198
ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
168199
merged_p = \
169200
get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.datasets.items()})[1]

0 commit comments

Comments
 (0)